diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 000000000..24bdaebc2 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,3 @@ +{ + "plugins": ["prettier-plugin-organize-imports"] +} \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index dba7ab2d4..45f2eb057 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -5,6 +5,7 @@ **vscode-dbt-power-user** is a comprehensive VSCode extension that makes VSCode seamlessly work with dbt (data build tool). It's an open-source project published by Altimate AI that extends VSCode with advanced dbt features including auto-completion, query preview, lineage visualization, documentation generation, and AI-powered features. ### Key Statistics + - **Version**: 0.57.3 - **Project Type**: VSCode Extension (TypeScript/React) - **License**: MIT @@ -25,16 +26,19 @@ The extension follows a **dependency injection pattern** using Inversify contain The extension operates across multiple processes: 1. **Main Extension Process** (Node.js/TypeScript) + - VSCode API integration - File system operations - dbt CLI interactions 2. **Webview Panels** (React/TypeScript) + - Modern React-based UI components - Located in `webview_panels/` directory - Built with Vite, uses Antd for UI components 3. **Python Bridge Integration** + - dbt core/cloud integration via Python scripts - Key files: `dbt_core_integration.py`, `dbt_cloud_integration.py` - Jupyter kernel for notebook functionality @@ -62,12 +66,14 @@ src/ ### 1. dbt Integration Support **Multiple Integration Types**: + - **dbt Core**: Direct Python integration via Python bridge - **dbt Cloud**: API-based integration with dbt Cloud services - **dbt Fusion**: Command-line integration with dbt-fusion CLI - **Core Command**: CLI wrapper integration for dbt core **Key Integration Files**: + - `src/dbt_client/dbtCoreIntegration.ts` - dbt Core Python integration - `src/dbt_client/dbtCloudIntegration.ts` - dbt Cloud API integration - `src/dbt_client/dbtFusionCommandIntegration.ts` - dbt Fusion CLI integration @@ -76,6 +82,7 @@ src/ ### 2. Language Server Features **Provider Architecture**: Each feature implemented as a separate provider: + - `autocompletion_provider/` - IntelliSense for dbt models, macros, sources - `definition_provider/` - Go-to-definition functionality - `hover_provider/` - Hover information @@ -85,12 +92,14 @@ src/ ### 3. Webview Panel System **Modern React Architecture** (`webview_panels/`): + - **Build System**: Vite + TypeScript + React 18 - **State Management**: Redux Toolkit - **UI Framework**: Antd + custom components - **Data Visualization**: Perspective.js, Plotly.js **Key Panels**: + - `modules/dataPilot/` - AI chat interface - `modules/queryPanel/` - Query results and analysis - `modules/lineage/` - Data lineage visualization @@ -100,12 +109,14 @@ src/ ### 4. AI and Advanced Features **DataPilot AI Integration**: + - Chat-based interface for dbt assistance - Query explanation and optimization - Documentation generation - Test suggestions **MCP Server Integration**: + - Tool calling for dbt operations - Integration with Claude and other AI models - Located in `src/mcp/server.ts` @@ -115,12 +126,14 @@ src/ ### 1. Multi-Stage Build Process **Main Extension Build** (Webpack): + ```bash npm run webpack # Development build npm run vscode:prepublish # Production build ``` **Webview Panels Build** (Vite): + ```bash npm run panel:webviews # Build React components ``` @@ -128,12 +141,15 @@ npm run panel:webviews # Build React components ### 2. 
Development Workflow **Key Scripts**: + +- `npm run compile` - Compile the code - `npm run watch` - Development with hot reload - `npm run test` - Jest-based testing - `npm run lint` - ESLint + Prettier - `npm run build-vsix` - Package extension **Development Environment**: + - Uses VSCode's built-in debugger ("Launch Extension") - Hot reload for webview panels - Python environment auto-detection @@ -141,6 +157,7 @@ npm run panel:webviews # Build React components ### 3. Testing Strategy **Test Configuration** (`jest.config.js`): + - **Unit Tests**: Jest + ts-jest - **Mock System**: Custom VSCode API mocks - **Coverage**: Istanbul-based coverage reporting @@ -151,6 +168,7 @@ npm run panel:webviews # Build React components ### 1. VSCode Extension Dependencies **Required Extensions**: + - `samuelcolvin.jinjahtml` - Jinja templating support - `ms-python.python` - Python environment integration - `altimateai.vscode-altimate-mcp-server` - MCP server @@ -158,12 +176,14 @@ npm run panel:webviews # Build React components ### 2. Major Technical Dependencies **Backend (Node.js)**: + - `inversify` - Dependency injection - `python-bridge` - Python process communication - `zeromq` - Jupyter kernel communication - `@modelcontextprotocol/sdk` - MCP protocol **Frontend (React)**: + - `react` 18 + `react-dom` - `@reduxjs/toolkit` - State management - `antd` - UI component library @@ -172,6 +192,7 @@ npm run panel:webviews # Build React components ### 3. Python Integration **Python Scripts**: + - `dbt_core_integration.py` - Core dbt operations - `dbt_cloud_integration.py` - Cloud API operations - `dbt_healthcheck.py` - Project health analysis @@ -182,6 +203,7 @@ npm run panel:webviews # Build React components ### 1. Extension Configuration **Comprehensive Settings** (190+ configuration options): + - dbt integration mode selection - Query limits and templates - AI features and endpoints @@ -191,6 +213,7 @@ npm run panel:webviews # Build React components ### 2. Language Support **File Type Associations**: + - `jinja-sql` - Primary dbt model files - `jinja-yaml` - dbt configuration files - `jinja-md` - Documentation files @@ -199,6 +222,7 @@ npm run panel:webviews # Build React components ### 3. Command System **80+ Commands Available**: + - Model execution (`dbtPowerUser.runCurrentModel`) - Documentation generation (`dbtPowerUser.generateSchemaYML`) - Query analysis (`dbtPowerUser.sqlLineage`) @@ -209,6 +233,7 @@ npm run panel:webviews # Build React components ### 1. Multi-Platform Distribution **CI/CD Pipeline** (`.github/workflows/ci.yml`): + - **Build Matrix**: macOS, Ubuntu, Windows - **Visual Studio Marketplace**: Primary distribution - **OpenVSX Registry**: Open-source alternative @@ -217,6 +242,7 @@ npm run panel:webviews # Build React components ### 2. Release Process **Automated Release**: + - Git tag triggers release pipeline - Pre-release and stable channel support - Slack notifications for release status @@ -234,16 +260,19 @@ npm run panel:webviews # Build React components ### 2. Adding New Features **For Language Features**: + 1. Create provider in appropriate `*_provider/` directory 2. Register in `inversify.config.ts` 3. Wire up in `DBTPowerUserExtension` **For UI Features**: + 1. Add React component in `webview_panels/src/modules/` 2. Update routing in `AppRoutes.tsx` 3. Add state management slice if needed **For dbt Integration**: + 1. Extend appropriate dbt client (`dbtCoreIntegration.ts` etc.) 2. Add Python bridge function if needed 3. 
Update MCP server tools if AI-accessible @@ -258,15 +287,19 @@ npm run panel:webviews # Build React components ## Common Development Patterns ### 1. Manifest-Driven Architecture + The extension heavily relies on dbt's `manifest.json` for understanding project structure. Most features key off manifest parsing events. ### 2. Multi-Integration Support + Always consider how features work across dbt core, cloud, and other integration types. Use strategy pattern for integration-specific behavior. ### 3. Webview Communication + Uses VSCode's webview messaging system with typed message contracts. State is synchronized between extension and webview contexts. ### 4. Python Bridge Pattern + For dbt operations requiring Python, use the established bridge pattern with JSON serialization and error handling. This architecture enables the extension to provide comprehensive dbt development support while maintaining modularity and extensibility for future enhancements. @@ -280,6 +313,7 @@ This architecture enables the extension to provide comprehensive dbt development The dbt Power User extension accelerates dbt and SQL development by 3x through three key phases: ### ๐Ÿ”ง DEVELOP + - **SQL Visualizer**: Visual query builder and analyzer - **Query Explanation**: AI-powered SQL query explanation - **Auto-generation**: Generate dbt models from sources or raw SQL @@ -288,7 +322,8 @@ The dbt Power User extension accelerates dbt and SQL development by 3x through t - **Query Translation**: Translate SQL between different dialects - **Compiled SQL Preview**: View compiled dbt code before execution -### ๐Ÿงช TEST +### ๐Ÿงช TEST + - **Query Results Preview**: Execute and analyze query results with export capabilities - **Test Generation**: AI-powered test generation for dbt models - **Column Lineage**: Detailed data lineage with code visibility @@ -297,6 +332,7 @@ The dbt Power User extension accelerates dbt and SQL development by 3x through t - **Model Lineage**: Visual representation of model dependencies ### ๐Ÿค COLLABORATE + - **Documentation Generation**: AI-powered documentation creation - **Code Collaboration**: Discussion threads on code and documentation - **Project Governance**: Automated checks for code quality and standards @@ -307,6 +343,7 @@ The dbt Power User extension accelerates dbt and SQL development by 3x through t ## DataMates AI Integration The extension includes **AI Teammates** through the DataMates Platform: + - **Coaching**: Personalize AI teammates for specific requirements - **Query Assistance**: AI-powered query explanation and optimization - **Documentation**: Automated documentation generation @@ -316,11 +353,13 @@ The extension includes **AI Teammates** through the DataMates Platform: ## Feature Availability **Free Extension Features**: + - SQL Visualizer, Model-level lineage, Auto-generation from sources - Auto-completion, Click to Run, Compiled SQL preview - Query results preview, Defer to production, SQL validation **With Altimate AI Key** (free signup at [app.myaltimate.com](https://app.myaltimate.com)): + - Column-level lineage, Query explanation AI, Query translation AI - Auto-generation from SQL, Test generation AI, Documentation generation AI - Code/documentation collaboration, Lineage export, SaaS UI @@ -333,39 +372,44 @@ The extension includes **AI Teammates** through the DataMates Platform: ## Installation Methods ### Native Installation + Install directly from [VS Code Marketplace](https://marketplace.visualstudio.com/items?itemName=innoverio.vscode-dbt-power-user) or via 
VS Code: + 1. Open VS Code Extensions panel (`Ctrl+Shift+X`) 2. Search for "dbt Power User" 3. Click Install 4. Reload VS Code if prompted ### Dev Container Installation + Add to your `.devcontainer/devcontainer.json`: + ```json { "customizations": { "vscode": { "files.associations": { "*.yaml": "jinja-yaml", - "*.yml": "jinja-yaml", + "*.yml": "jinja-yaml", "*.sql": "jinja-sql", "*.md": "jinja-md" }, - "extensions": [ - "innoverio.vscode-dbt-power-user" - ] + "extensions": ["innoverio.vscode-dbt-power-user"] } } } ``` ### Cursor IDE Support + The extension is also available for [Cursor IDE](https://www.cursor.com/how-to-install-extension). Install the same way as VS Code. ## Required Configuration ### 1. dbt Integration Setup + Configure how the extension connects to dbt: + - **dbt Core**: For local dbt installations with Python bridge (default) - **dbt Cloud**: For dbt Cloud API integration - **dbt Fusion**: For dbt-fusion CLI integration @@ -374,22 +418,28 @@ Configure how the extension connects to dbt: Set via `dbt.dbtIntegration` setting. #### dbt Fusion Integration + dbt Fusion is a command-line interface that provides enhanced dbt functionality. When using fusion integration: + - Requires dbt-fusion CLI to be installed in your environment - Extension automatically detects fusion installation via `dbt --version` output - Provides full feature support including query execution, compilation, and catalog operations - Uses JSON log format for structured command output parsing ### 2. Python Environment + Ensure Python and dbt are properly installed and accessible. The extension will auto-detect your Python environment through the VS Code Python extension. ### 3. Optional: Altimate AI Key + For advanced AI features, get a free API key: + 1. Sign up at [app.myaltimate.com/register](https://app.myaltimate.com/register) 2. Add API key to `dbt.altimateAiKey` setting 3. Set instance name in `dbt.altimateInstanceName` setting ## Project Setup + 1. Open your dbt project folder in VS Code 2. Run the setup wizard: Select "dbt" in bottom status bar โ†’ "Setup Extension" 3. The extension will auto-install dbt dependencies if enabled @@ -402,52 +452,65 @@ For advanced AI features, get a free API key: ## Quick Diagnostics ### 1. Setup Wizard + Use the built-in setup wizard for automated issue detection: + - Click "dbt" or "dbt is not installed" in bottom status bar -- Select "Setup Extension" +- Select "Setup Extension" - Follow guided setup process ### 2. Diagnostics Command + Run comprehensive system diagnostics: + - Open Command Palette (`Cmd+Shift+P` / `Ctrl+Shift+P`) - Type "diagnostics" โ†’ Select "dbt Power User: Diagnostics" - Review output for environment issues, Python/dbt installation status, and connection problems ### 3. Problems Panel + Check VS Code Problems panel for dbt project issues: + - View โ†’ Problems (or `Ctrl+Shift+M`) - Look for dbt-related validation errors ## Debug Logging Enable detailed logging for troubleshooting: + 1. Command Palette โ†’ "Set Log Level" โ†’ "Debug" 2. View logs: Output panel โ†’ "Log" dropdown โ†’ "dbt" 3. 
Reproduce the issue to capture debug information ## Developer Tools + For advanced debugging: + - Help โ†’ Toggle Developer Tools - Check console for JavaScript errors and detailed logs ## Common Issues **Extension not recognizing dbt project**: + - Verify `dbt_project.yml` exists in workspace root - Check Python environment has dbt installed - Run diagnostics command for detailed analysis **Python/dbt not found**: + - Configure Python interpreter via VS Code Python extension - Verify dbt is installed in selected Python environment - Set `dbt.dbtPythonPathOverride` if using custom Python path **Connection issues**: + - Verify database connection in dbt profiles - Check firewall/network settings - Review connection details in diagnostics output ## Getting Help + - Join [#tools-dbt-power-user](https://getdbt.slack.com/archives/C05KPDGRMDW) in dbt Community Slack - Contact support at [altimate.ai/support](https://www.altimate.ai/support) - Use in-extension feedback widgets for feature-specific issues @@ -459,38 +522,45 @@ For advanced debugging: ## Auto-completion and Navigation ### Model Auto-completion + - **Smart IntelliSense**: Auto-complete model names with `ref()` function - **Go-to-Definition**: Navigate directly to model files - **Hover Information**: View model details on hover -### Macro Support +### Macro Support + - **Macro Auto-completion**: IntelliSense for custom and built-in macros - **Parameter Hints**: Auto-complete macro parameters - **Definition Navigation**: Jump to macro definitions ### Source Integration + - **Source Auto-completion**: IntelliSense for configured sources - **Column Awareness**: Auto-complete source column names - **Schema Navigation**: Navigate to source definitions ### Documentation Blocks + - **Doc Block Auto-completion**: IntelliSense for documentation references - **Definition Linking**: Navigate to doc block definitions ## Query Development ### SQL Compilation and Preview + - **Compiled Code View**: See final SQL before execution - **Template Resolution**: Preview Jinja templating results - **Syntax Highlighting**: Enhanced SQL syntax highlighting for dbt files ### Query Execution + - **Preview Results**: Execute queries with `Cmd+Enter` / `Ctrl+Enter` - **Result Analysis**: Export results as CSV, copy as JSON - **Query History**: Track executed queries - **Configurable Limits**: Set row limits for query previews (default: 500 rows) ### SQL Formatting + - **Auto-formatting**: Integration with sqlfmt - **Custom Parameters**: Configure formatting rules - **Batch Processing**: Format multiple files @@ -498,17 +568,20 @@ For advanced debugging: ## AI-Powered Development ### Query Explanation + - **Natural Language**: Get plain English explanations of complex SQL - **Step-by-step Analysis**: Breakdown of query logic - **Performance Insights**: Query optimization suggestions ### Code Generation + - **Model from Source**: Generate base models from source tables - **Model from SQL**: Convert raw SQL to dbt models - **Test Generation**: AI-powered test suggestions - **Documentation Generation**: Auto-generate model documentation ### Query Translation + - **Cross-dialect Support**: Translate SQL between database dialects - **Syntax Adaptation**: Handle dialect-specific functions and syntax @@ -526,19 +599,23 @@ This is a MkDocs-based documentation site for the dbt Power User VSCode Extensio ### Architecture #### Content Organization + - `documentation/docs/` contains all documentation content in Markdown format - Content is organized by feature areas: 
`setup/`, `develop/`, `test/`, `document/`, `govern/`, `discover/`, `teammates/`, `datamates/`, `arch/` - Images and assets are stored within feature-specific directories - `documentation/mkdocs.yml` contains all site configuration #### Key Configuration Files + - `documentation/mkdocs.yml`: Main site configuration including navigation, theme settings, and plugins - `documentation/requirements.txt`: Python dependencies for MkDocs and plugins - `documentation/docs/overrides/`: Custom theme overrides (currently empty) - `documentation/docs/javascripts/`: Custom JavaScript for enhanced functionality #### Theme Configuration + The site uses Material theme with: + - Custom Altimate AI branding and colors - Google Analytics integration (G-LXRSS3VK5N) - Git revision date tracking via plugin @@ -546,7 +623,9 @@ The site uses Material theme with: - Dark/light mode support #### Navigation Structure + Navigation follows a three-phase user journey: + 1. **Setup**: Installation and configuration 2. **Develop**: Core development features 3. **Test**: Testing and validation tools @@ -555,18 +634,21 @@ Navigation follows a three-phase user journey: ### Working with Content #### Adding New Pages + 1. Create `.md` files in the appropriate `docs/` subdirectory 2. Update the `nav` section in `mkdocs.yml` to include the new page 3. Follow existing naming conventions for consistency #### Images and Assets + - Store images in the same directory as the referencing markdown file - Use relative paths for image references - Common assets go in `docs/assets/` #### Internal Links + Use relative markdown links to reference other pages. The site has extensive cross-referencing between related features. ### Testing Changes -Always test locally with `mkdocs serve` before deploying. The development server provides live reload for content changes. \ No newline at end of file +Always test locally with `mkdocs serve` before deploying. The development server provides live reload for content changes. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1985a9137..bda79a72f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -59,6 +59,99 @@ Maintaining consistent code style and formatting is crucial for readability and When introducing new features or addressing bugs, consider the current codebase and community needs. Engage in discussions with fellow community members if you're unsure about design decisions or implementation details. +## Local Development with @altimateai/dbt-integration Library + +When working on the extension, you may need to develop against a local version of the `@altimateai/dbt-integration` library instead of the published npm package. This allows you to test changes to both the extension and the integration library simultaneously. + +### Setup for Local Development + +1. **Clone the dbt-integration repository:** First, ensure you have the `altimate-dbt-integration` repository cloned as a sibling directory to this project: + + ```bash + cd /path/to/your/projects + git clone https://github.com/altimateai/altimate-dbt-integration.git + cd vscode-dbt-power-user + ``` + + Your directory structure should look like: + + ``` + /path/to/your/projects/ + โ”œโ”€โ”€ altimate-dbt-integration/ + โ””โ”€โ”€ vscode-dbt-power-user/ + ``` + +2. 
**Switch to local development mode:** Modify the following configuration files to use the local TypeScript source instead of the npm package: + + **jest.config.js**: Uncomment the local development lines: + + ```javascript + // Development: use local TypeScript source (same as webpack and tsconfig) + "^@altimateai/dbt-integration$": + "/../altimate-dbt-integration/src/index.ts", + // Production: use npm package (commented out for development) + // "^@altimateai/dbt-integration$": "@altimateai/dbt-integration", + ``` + + **tsconfig.json**: Update the configuration: + + ```json + { + // "rootDir": "src", + "rootDirs": ["src", "../altimate-dbt-integration/src"], + "paths": { + "@altimateai/dbt-integration": [ + "../altimate-dbt-integration/src/index.ts" + ], + "@extension": ["./src/modules.ts"], + "@lib": ["./src/lib/index"] + } + } + ``` + + **webpack.config.js**: Update the alias and copy plugin configurations: + + ```javascript + // In resolve.alias section: + "@altimateai/dbt-integration": path.resolve( + __dirname, + "../altimate-dbt-integration/src/index.ts", + ), + + // In CopyWebpackPlugin, comment out production copies and uncomment development copies: + // Development: use local Python files + { + from: path.resolve( + __dirname, + "../altimate-dbt-integration/node_modules/python-bridge/node_python_bridge.py", + ), + to: "node_python_bridge.py", + }, + // ... (other local file copies) + ``` + +### Switching Back to Production Mode + +When you're done with local development, revert the configuration changes to use the published npm package: + +1. **jest.config.js**: Comment out local development lines and uncomment production lines +2. **tsconfig.json**: Set `"rootDir": "src"` and remove the local path mapping +3. **webpack.config.js**: Remove local alias and use npm package copies in CopyWebpackPlugin + +### Benefits of Local Development Mode + +- **Real-time changes**: Modify both the extension and integration library simultaneously +- **Debugging**: Set breakpoints and debug across both codebases +- **Testing**: Test integration library changes before publishing +- **Development workflow**: Faster iteration when working on features that span both repositories + +### Important Notes + +- Ensure both repositories are on compatible branches when doing local development +- The local development setup expects the `altimate-dbt-integration` directory to be a sibling of `vscode-dbt-power-user` +- Always test with the production npm package configuration before submitting pull requests +- The Python files from the integration library are copied during the webpack build process + ## Testing Comprehensive testing is essential for maintaining the extension's stability and reliability. While adding new features or fixing bugs, run tests locally helps ensure your changes don't introduce regressions. 
diff --git a/altimate_packages/altimate/__init__.py b/altimate_packages/altimate/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/altimate_packages/altimate/fetch_schema.py b/altimate_packages/altimate/fetch_schema.py deleted file mode 100644 index ba272a58f..000000000 --- a/altimate_packages/altimate/fetch_schema.py +++ /dev/null @@ -1,35 +0,0 @@ -import sqlglot -from sqlglot.optimizer.qualify import qualify -from altimate.utils import map_adapter_to_dialect - - -def fetch_schema(sql: str, dialect: str): - parsed_query = sqlglot.parse_one(sql=sql, dialect=map_adapter_to_dialect(dialect)) - columns = [] - for c in parsed_query.selects: - if c.key == "column": - if c.args["this"].key == "star": - raise Exception( - f"unable fetched schema due to star: {c.sql(pretty=True)}" - ) - columns.append(c.alias_or_name) - elif c.key == "alias": - columns.append(c.alias_or_name) - else: - raise Exception(f"unknown key '{c.key}' detected for {c.sql(pretty=True)}") - return columns - - -def validate_whether_sql_has_columns(sql: str, dialect: str): - try: - parsed_query = sqlglot.parse_one(sql=sql, dialect=map_adapter_to_dialect(dialect)) - qualify( - parsed_query, - schema={}, - dialect=dialect, - quote_identifiers=False, - validate_qualify_columns=True, - ) - return True - except Exception as e: - return False \ No newline at end of file diff --git a/altimate_packages/altimate/utils.py b/altimate_packages/altimate/utils.py deleted file mode 100644 index 9c823e38a..000000000 --- a/altimate_packages/altimate/utils.py +++ /dev/null @@ -1,353 +0,0 @@ -import re - -import sqlglot -from sqlglot.executor import execute -from sqlglot.expressions import Table -from sqlglot.optimizer import traverse_scope -from sqlglot.optimizer.qualify import qualify - -ADAPTER_MAPPING = { - "bigquery": "bigquery", - "clickhouse": "clickhouse", - "databricks": "databricks", - "duckdb": "duckdb", - "hive": "hive", - "mysql": "mysql", - "oracle": "oracle", - "postgres": "postgres", - "redshift": "redshift", - "snowflake": "snowflake", - "spark": "spark", - "starrocks": "starrocks", - "teradata": "teradata", - "trino": "trino", - "synapse": "tsql", - "sqlserver": "tsql", - "doris": "doris", - "athena": "presto", -} - -MULTIPLE_OCCURENCES_STR = "Unable to highlight the exact location in the SQL code due to multiple occurrences." -MAPPING_FAILED_STR = "Unable to highlight the exact location in the SQL code." 
- - -def extract_column_name(text): - # List of regex patterns - regex_patterns = [ - r"Column '\"(\w+)\"' could not be resolved", - r"Unknown column: (\w+)", - r"Column '(\w+)' could not be resolved", - r"Unknown output column: (\w+)", - r"Cannot automatically join: (\w+)", - ] - - # Iterate over each regex pattern - for regex in regex_patterns: - matches = re.findall(regex, text) - if matches: - return matches[0] - - return None - - -def find_single_occurrence_indices(main_string, substring): - # Convert both strings to lowercase for case-insensitive comparison - main_string = main_string.lower() - substring = substring.lower() if substring else "" - - if not substring: - # return consistent tuple when substring is empty - return None, None, 0 - - num_occurrences = main_string.count(substring) - # Check if the substring occurs only once in the main string - if num_occurrences == 1: - start_index = main_string.find(substring) - return start_index, start_index + len(substring), num_occurrences - - # Return None if the substring doesn't occur exactly once - return None, None, num_occurrences - - -def map_adapter_to_dialect(adapter: str): - return ADAPTER_MAPPING.get(adapter, adapter) - - -def get_str_position(str, row, col): - """ - Get the position of a grid position in a string - """ - lines = str.split("\n") - position = 0 - for i in range(row - 1): - position += len(lines[i]) + 1 - position += col - return position - - -def get_line_and_column_from_position(text, start_index): - """ - Finds the grid position (row and column) in a multiline string given a Python start index. - Rows and columns are 1-indexed. - - :param text: Multiline string. - :param start_index: Python start index (0-indexed). - :return: Tuple of (row, column). - """ - row = 0 - current_length = 0 - - # Split the text into lines - lines = text.split("\n") - - for line in lines: - # Check if the start_index is in the current line - if current_length + len(line) >= start_index: - # Column is the difference between start_index and the length of processed characters - column = start_index - current_length + 1 - return row, column - - # Update the row and current length for the next iteration - row += 1 - current_length += len(line) + 1 # +1 for the newline character - - return None, None - - -def _build_message(sql: str, error: dict): - len_highlight = len(error.get("highlight", "")) - len_prefix = len(error.get("start_context", "")) - if error.get("line") and error.get("col"): - end_position = get_str_position(sql, error["line"], error["col"]) - start_position = end_position - len_highlight - len_prefix - row, col = get_line_and_column_from_position(sql, start_position) - return { - "description": "Failed to parse the sql query", - "start_position": [row, col], - "end_position": [error["line"], error["col"]], - } - return {"description": "Failed to parse the sql query"} - - -def sql_parse_errors(sql: str, dialect: str): - errors = [] - try: - sqlglot.transpile(sql, read=dialect) - ast = sqlglot.parse_one(sql, read=dialect) - if isinstance(ast, sqlglot.exp.Alias): - return [ - { - "description": "Failed to parse the sql query.", - } - ] - except sqlglot.errors.ParseError as e: - for error in e.errors: - errors.append(_build_message(sql, error)) - return errors - - -def get_start_and_end_position(sql: str, invalid_string: str): - start, end, num_occurences = find_single_occurrence_indices(sql, invalid_string) - if start and end: - return ( - list(get_line_and_column_from_position(sql, start)), - 
list(get_line_and_column_from_position(sql, end)), - num_occurences, - ) - return None, None, num_occurences - - -def form_error( - error: str, invalid_entity: str, start_position, end_position, num_occurences -): - if num_occurences > 1: - error = ( - f"{error}\n {MULTIPLE_OCCURENCES_STR.format(invalid_entity=invalid_entity)}" - ) - return { - "description": error, - } - - if not start_position or not end_position: - error = ( - f"{error}\n {MAPPING_FAILED_STR.format(invalid_entity=invalid_entity)}" - if invalid_entity - else error - ) - return { - "description": error, - } - - return { - "description": error, - "start_position": start_position, - "end_position": end_position, - } - - -def validate_tables_and_columns( - sql: str, - dialect: str, - schemas: dict, -): - try: - parsed_sql = sqlglot.parse_one(sql, read=dialect) - qualify(parsed_sql, dialect=dialect, schema=schemas) - except sqlglot.errors.OptimizeError as e: - error = str(e) - if "sqlglot" in error: - error = "Failed to validate the query." - invalid_entity = extract_column_name(error) - if not invalid_entity: - return [ - { - "description": error, - } - ] - start_position, end_position, num_occurences = get_start_and_end_position( - sql, invalid_entity - ) - error = error if error[-1] == "." else error + "." - return [ - form_error( - error, invalid_entity, start_position, end_position, num_occurences - ) - ] - - return None - - -def sql_execute_errors( - sql: str, - dialect: str, - schemas: dict, -): - tables = {} - for db in schemas: - if db not in tables: - tables[db] = {} - for schema in schemas[db]: - if schema not in tables[db]: - tables[db][schema] = {} - for table in schemas[db][schema]: - tables[db][schema][table] = [] - - try: - execute( - sql=sql, - read=dialect, - schema=schemas, - tables=tables, - ) - except sqlglot.errors.ExecuteError as e: - return [ - { - "description": str(e), - } - ] - return None - - -def qualify_columns(expression): - """ - Qualify the columns in the given SQL expression. - """ - try: - return qualify( - expression, - qualify_columns=True, - isolate_tables=True, - validate_qualify_columns=False, - ) - except sqlglot.errors.OptimizeError as error: - return expression - - -def parse_sql_query(sql_query, dialect): - """ - Parses the SQL query and returns an AST. - """ - return sqlglot.parse_one(sql_query, read=dialect) - - -def extract_physical_columns(ast): - """ - Extracts physical columns from the given AST. - """ - physical_columns = {} - for scope in traverse_scope(ast): - for column in scope.columns: - table = scope.sources.get(column.table) - if isinstance(table, Table): - db, schema, table_name = table.catalog, table.db, table.name - if db is None or schema is None: - continue - path = f"{db}.{schema}.{table_name}".lower() - physical_columns.setdefault(path, set()).add(column.name) - return physical_columns - - -def get_columns_used(sql_query, dialect): - """ - Process the SQL query to extract physical columns. - """ - ast = parse_sql_query(sql_query, dialect) - qualified_ast = qualify_columns(ast) - return extract_physical_columns(qualified_ast) - - -def validate_columns_present_in_schema(sql_query, dialect, schemas, model_mapping): - """ - Validate that the columns in the SQL query are present in the schema. 
- """ - errors = [] - new_schemas = {} - for db in schemas: - for schema in schemas[db]: - for table in schemas[db][schema]: - path = f"{db}.{schema}.{table}".lower() - new_schemas.setdefault(path, set()).update( - [column.lower() for column in schemas[db][schema][table].keys()] - ) - schemas = new_schemas - try: - columns_used = get_columns_used(sql_query, dialect) - - for table, columns_set in columns_used.items(): - if table not in schemas: - ( - start_position, - end_position, - num_occurences, - ) = get_start_and_end_position(sql_query, table) - error = f"Error: Table '{table}' not found. This issue often occurs when a table is used directly\n in dbt instead of being referenced through the appropriate syntax.\n To resolve this, ensure that '{table}' is propaerly defined in your project and use the 'ref()' function to reference it in your models." - - errors.append( - form_error( - error, table, start_position, end_position, num_occurences - ) - ) - continue - - columns = schemas[table] - for column in columns_set: - if column.lower() not in columns: - ( - start_position, - end_position, - num_occurences, - ) = get_start_and_end_position(sql_query, column) - table = model_mapping.get(table, table) - error = f"Error: Column '{column}' not found in '{table}'. \nPossible causes: 1) Typo in column name. 2) Column not materialized. 3) Column not selected in parent cte." - errors.append( - form_error( - error, - column, - start_position, - end_position, - num_occurences, - ) - ) - except Exception as e: - pass - return errors diff --git a/altimate_packages/altimate/validate_sql.py b/altimate_packages/altimate/validate_sql.py deleted file mode 100644 index dd32578b9..000000000 --- a/altimate_packages/altimate/validate_sql.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import Dict, List - -from altimate.utils import ( - map_adapter_to_dialect, - sql_execute_errors, - sql_parse_errors, - validate_columns_present_in_schema, - validate_tables_and_columns, -) - - -def _get_key( - key: str, - dialect: str, -): - if dialect == "bigquery": - return key.lower() - - if dialect == "snowflake": - return key.upper() - return key - - -def _build_schemas( - models: List[Dict], - dialect: str, -): - """ - TODO: Duplicated in multiple places with slight variations. Fix this. 
- """ - schemas = {} - for model in models: - schema = {} - for column in model["columns"]: - schema[_get_key(model["columns"][column]["name"], dialect)] = model[ - "columns" - ][column].get("data_type", "string") - - db = _get_key(model["database"], dialect) - schema_name = _get_key(model["schema"], dialect) - table = _get_key(model["alias"], dialect) - if db not in schemas: - schemas[db] = {} - - if schema_name not in schemas[db]: - schemas[db][schema_name] = {} - - schemas[db][schema_name][table] = schema - - return schemas - - -def _build_model_mapping( - models: List[Dict], -): - model_map = {} - for model in models: - db = model["database"] - schema = model["schema"] - table = model["alias"] - model_map[f"{db}.{schema}.{table}".lower()] = model["name"] - return model_map - - -def validate_sql_from_models( - sql: str, - dialect: str, - models: List[Dict], -): - """ - Validate SQL from models - """ - try: - dialect = map_adapter_to_dialect(dialect) - schemas = _build_schemas(models, dialect) - model_mapping = _build_model_mapping(models) - errors = sql_parse_errors(sql, dialect) - - if len(errors) > 0: - return { - "error_type": "sql_parse_error", - "errors": errors, - } - - errors = validate_columns_present_in_schema( - sql, dialect, schemas, model_mapping - ) - if len(errors) > 0: - return { - "error_type": "sql_invalid_error", - "errors": errors, - } - - errors = validate_tables_and_columns(sql, dialect, schemas) - - if errors: - return { - "error_type": "sql_invalid_error", - "errors": errors, - } - - # errors = sql_execute_errors(sql, dialect, schemas) - - # if errors: - # return {"error_type": "sql_execute_error", "errors": errors} - - except Exception as e: - return { - "error_type": "sql_unknown_error", - "errors": [ - {"description": f"Unknown error. Cannot validate SQL. {str(e)}"} - ], - } - return {} diff --git a/altimate_packages/sqlglot/LICENSE b/altimate_packages/sqlglot/LICENSE deleted file mode 100644 index 8a19fcfef..000000000 --- a/altimate_packages/sqlglot/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2023 Toby Mao - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
\ No newline at end of file diff --git a/altimate_packages/sqlglot/README.md b/altimate_packages/sqlglot/README.md deleted file mode 100644 index 1104903f0..000000000 --- a/altimate_packages/sqlglot/README.md +++ /dev/null @@ -1,3 +0,0 @@ -### Sqlglot - -Currently using version 18.3.0 of sqlglot diff --git a/altimate_packages/sqlglot/__init__.py b/altimate_packages/sqlglot/__init__.py deleted file mode 100644 index 33d347469..000000000 --- a/altimate_packages/sqlglot/__init__.py +++ /dev/null @@ -1,178 +0,0 @@ -# ruff: noqa: F401 -""" -.. include:: ../README.md - ----- -""" - -from __future__ import annotations - -import logging -import typing as t - -from sqlglot import expressions as exp -from sqlglot.dialects.dialect import Dialect as Dialect, Dialects as Dialects -from sqlglot.diff import diff as diff -from sqlglot.errors import ( - ErrorLevel as ErrorLevel, - ParseError as ParseError, - TokenError as TokenError, - UnsupportedError as UnsupportedError, -) -from sqlglot.expressions import ( - Expression as Expression, - alias_ as alias, - and_ as and_, - case as case, - cast as cast, - column as column, - condition as condition, - delete as delete, - except_ as except_, - from_ as from_, - func as func, - insert as insert, - intersect as intersect, - maybe_parse as maybe_parse, - merge as merge, - not_ as not_, - or_ as or_, - select as select, - subquery as subquery, - table_ as table, - to_column as to_column, - to_identifier as to_identifier, - to_table as to_table, - union as union, -) -from sqlglot.generator import Generator as Generator -from sqlglot.parser import Parser as Parser -from sqlglot.schema import MappingSchema as MappingSchema, Schema as Schema -from sqlglot.tokens import Token as Token, Tokenizer as Tokenizer, TokenType as TokenType - -if t.TYPE_CHECKING: - from sqlglot._typing import E - from sqlglot.dialects.dialect import DialectType as DialectType - -logger = logging.getLogger("sqlglot") - - -try: - from sqlglot._version import __version__, __version_tuple__ -except ImportError: - logger.error( - "Unable to set __version__, run `pip install -e .` or `python setup.py develop` first." - ) - - -pretty = False -"""Whether to format generated SQL by default.""" - - -def tokenize(sql: str, read: DialectType = None, dialect: DialectType = None) -> t.List[Token]: - """ - Tokenizes the given SQL string. - - Args: - sql: the SQL code string to tokenize. - read: the SQL dialect to apply during tokenizing (eg. "spark", "hive", "presto", "mysql"). - dialect: the SQL dialect (alias for read). - - Returns: - The resulting list of tokens. - """ - return Dialect.get_or_raise(read or dialect).tokenize(sql) - - -def parse( - sql: str, read: DialectType = None, dialect: DialectType = None, **opts -) -> t.List[t.Optional[Expression]]: - """ - Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement. - - Args: - sql: the SQL code string to parse. - read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). - dialect: the SQL dialect (alias for read). - **opts: other `sqlglot.parser.Parser` options. - - Returns: - The resulting syntax tree collection. - """ - return Dialect.get_or_raise(read or dialect).parse(sql, **opts) - - -@t.overload -def parse_one(sql: str, *, into: t.Type[E], **opts) -> E: ... - - -@t.overload -def parse_one(sql: str, **opts) -> Expression: ... 
- - -def parse_one( - sql: str, - read: DialectType = None, - dialect: DialectType = None, - into: t.Optional[exp.IntoType] = None, - **opts, -) -> Expression: - """ - Parses the given SQL string and returns a syntax tree for the first parsed SQL statement. - - Args: - sql: the SQL code string to parse. - read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). - dialect: the SQL dialect (alias for read) - into: the SQLGlot Expression to parse into. - **opts: other `sqlglot.parser.Parser` options. - - Returns: - The syntax tree for the first parsed statement. - """ - - dialect = Dialect.get_or_raise(read or dialect) - - if into: - result = dialect.parse_into(into, sql, **opts) - else: - result = dialect.parse(sql, **opts) - - for expression in result: - if not expression: - raise ParseError(f"No expression was parsed from '{sql}'") - return expression - else: - raise ParseError(f"No expression was parsed from '{sql}'") - - -def transpile( - sql: str, - read: DialectType = None, - write: DialectType = None, - identity: bool = True, - error_level: t.Optional[ErrorLevel] = None, - **opts, -) -> t.List[str]: - """ - Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed - to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement. - - Args: - sql: the SQL code string to transpile. - read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql"). - write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql"). - identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both: - the source and the target dialect. - error_level: the desired error level of the parser. - **opts: other `sqlglot.generator.Generator` options. - - Returns: - The list of transpiled SQL statements. 
- """ - write = (read if write is None else write) if identity else write - write = Dialect.get_or_raise(write) - return [ - write.generate(expression, copy=False, **opts) if expression else "" - for expression in parse(sql, read, error_level=error_level) - ] diff --git a/altimate_packages/sqlglot/__main__.py b/altimate_packages/sqlglot/__main__.py deleted file mode 100644 index 5a77409ed..000000000 --- a/altimate_packages/sqlglot/__main__.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import annotations - -import argparse -import sys -import typing as t - -import sqlglot - -parser = argparse.ArgumentParser(description="Transpile SQL") -parser.add_argument( - "sql", - metavar="sql", - type=str, - help="SQL statement(s) to transpile, or - to parse stdin.", -) -parser.add_argument( - "--read", - dest="read", - type=str, - default=None, - help="Dialect to read default is generic", -) -parser.add_argument( - "--write", - dest="write", - type=str, - default=None, - help="Dialect to write default is generic", -) -parser.add_argument( - "--no-identify", - dest="identify", - action="store_false", - help="Don't auto identify fields", -) -parser.add_argument( - "--no-pretty", - dest="pretty", - action="store_false", - help="Compress sql", -) -parser.add_argument( - "--parse", - dest="parse", - action="store_true", - help="Parse and return the expression tree", -) -parser.add_argument( - "--tokenize", - dest="tokenize", - action="store_true", - help="Tokenize and return the tokens list", -) -parser.add_argument( - "--error-level", - dest="error_level", - type=str, - default="IMMEDIATE", - help="IGNORE, WARN, RAISE, IMMEDIATE (default)", -) -parser.add_argument( - "--version", - action="version", - version=sqlglot.__version__, - help="Display the SQLGlot version", -) - - -args = parser.parse_args() -error_level = sqlglot.ErrorLevel[args.error_level.upper()] - -sql = sys.stdin.read() if args.sql == "-" else args.sql - -if args.parse: - objs: t.Union[t.List[str], t.List[sqlglot.tokens.Token]] = [ - repr(expression) - for expression in sqlglot.parse( - sql, - read=args.read, - error_level=error_level, - ) - ] -elif args.tokenize: - objs = sqlglot.Dialect.get_or_raise(args.read).tokenize(sql) -else: - objs = sqlglot.transpile( - sql, - read=args.read, - write=args.write, - identify=args.identify, - pretty=args.pretty, - error_level=error_level, - ) - -for obj in objs: - print(obj) diff --git a/altimate_packages/sqlglot/_typing.py b/altimate_packages/sqlglot/_typing.py deleted file mode 100644 index 0415aa41f..000000000 --- a/altimate_packages/sqlglot/_typing.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import annotations - -import typing as t - -import sqlglot - -if t.TYPE_CHECKING: - from typing_extensions import Literal as Lit # noqa - -# A little hack for backwards compatibility with Python 3.7. -# For example, we might want a TypeVar for objects that support comparison e.g. SupportsRichComparisonT from typeshed. -# But Python 3.7 doesn't support Protocols, so we'd also need typing_extensions, which we don't want as a dependency. 
-A = t.TypeVar("A", bound=t.Any) -B = t.TypeVar("B", bound="sqlglot.exp.Binary") -E = t.TypeVar("E", bound="sqlglot.exp.Expression") -F = t.TypeVar("F", bound="sqlglot.exp.Func") -T = t.TypeVar("T") diff --git a/altimate_packages/sqlglot/dataframe/README.md b/altimate_packages/sqlglot/dataframe/README.md deleted file mode 100644 index 01be08ec0..000000000 --- a/altimate_packages/sqlglot/dataframe/README.md +++ /dev/null @@ -1,242 +0,0 @@ -# PySpark DataFrame SQL Generator - -This is a drop-in replacement for the PySpark DataFrame API that will generate SQL instead of executing DataFrame operations directly. This, when combined with the transpiling support in SQLGlot, allows one to write PySpark DataFrame code and execute it on other engines like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). - -Currently many of the common operations are covered and more functionality will be added over time. Please [open an issue](https://github.com/tobymao/sqlglot/issues) or [PR](https://github.com/tobymao/sqlglot/pulls) with your feedback or contribution to help influence what should be prioritized next and make sure your use case is properly supported. - -# How to use - -## Instructions - -- [Install SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md#install) and that is all that is required to just generate SQL. [The examples](#examples) show generating SQL and then executing that SQL on a specific engine and that will require that engine's client library. -- Find/replace all `from pyspark.sql` with `from sqlglot.dataframe`. -- Prior to any `spark.read.table` or `spark.table` run `sqlglot.schema.add_table('', , dialect="spark")`. - - The column structure can be defined the following ways: - - Dictionary where the keys are column names and values are string of the Spark SQL type name. - - Ex: `{'cola': 'string', 'colb': 'int'}` - - PySpark DataFrame `StructType` similar to when using `createDataFrame`. - - Ex: `StructType([StructField('cola', StringType()), StructField('colb', IntegerType())])` - - A string of names and types similar to what is supported in `createDataFrame`. - - Ex: `cola: STRING, colb: INT` - - [Not Recommended] A list of string column names without type. - - Ex: `['cola', 'colb']` - - The lack of types may limit functionality in future releases. - - See [Registering Custom Schema](#registering-custom-schema-class) for information on how to skip this step if the information is stored externally. -- If your output SQL dialect is not Spark, then configure the SparkSession to use that dialect - - Ex: `SparkSession().builder.config("sqlframe.dialect", "bigquery").getOrCreate()` - - See [dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) for a full list of dialects. -- Add `.sql(pretty=True)` to your final DataFrame command to return a list of sql statements to run that command. - - In most cases a single SQL statement is returned. Currently the only exception is when caching DataFrames which isn't supported in other dialects. 
- - Ex: `.sql(pretty=True)` - -## Examples - -```python -import sqlglot as sqlglot -from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.dataframe.sql import functions as F - -dialect = "spark" - -sqlglot.schema.add_table( - 'employee', - { - 'employee_id': 'INT', - 'fname': 'STRING', - 'lname': 'STRING', - 'age': 'INT', - }, - dialect=dialect, -) # Register the table structure prior to reading from the table - -spark = SparkSession.builder.config("sqlframe.dialect", dialect).getOrCreate() - -df = ( - spark - .table('employee') - .groupBy(F.col("age")) - .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) -) - -print(df.sql(pretty=True)) -``` - -```sparksql -SELECT - `employee`.`age` AS `age`, - COUNT(DISTINCT `employee`.`employee_id`) AS `num_employees` -FROM `employee` AS `employee` -GROUP BY - `employee`.`age` -``` - -## Registering Custom Schema Class - -The step of adding `sqlglot.schema.add_table` can be skipped if you have the column structure stored externally like in a file or from an external metadata table. This can be done by writing a class that implements the `sqlglot.schema.Schema` abstract class and then assigning that class to `sqlglot.schema`. - -```python -import sqlglot as sqlglot -from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.dataframe.sql import functions as F -from sqlglot.schema import Schema - - -class ExternalSchema(Schema): - ... - -sqlglot.schema = ExternalSchema() - -spark = SparkSession() # Spark will be used by default is not specific in SparkSession config - -df = ( - spark - .table('employee') - .groupBy(F.col("age")) - .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) -) - -print(df.sql(pretty=True)) -``` - -## Example Implementations - -### Bigquery - -```python -from google.cloud import bigquery -from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.dataframe.sql import types -from sqlglot.dataframe.sql import functions as F - -client = bigquery.Client() - -data = [ - (1, "Jack", "Shephard", 34), - (2, "John", "Locke", 48), - (3, "Kate", "Austen", 34), - (4, "Claire", "Littleton", 22), - (5, "Hugo", "Reyes", 26), -] -schema = types.StructType([ - types.StructField('employee_id', types.IntegerType(), False), - types.StructField('fname', types.StringType(), False), - types.StructField('lname', types.StringType(), False), - types.StructField('age', types.IntegerType(), False), -]) - -sql_statements = ( - SparkSession - .builder - .config("sqlframe.dialect", "bigquery") - .getOrCreate() - .createDataFrame(data, schema) - .groupBy(F.col("age")) - .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) - .sql() -) - -result = None -for sql in sql_statements: - result = client.query(sql) - -assert result is not None -for row in client.query(result): - print(f"Age: {row['age']}, Num Employees: {row['num_employees']}") -``` - -### Snowflake - -```python -import os - -import snowflake.connector -from sqlglot.dataframe.session import SparkSession -from sqlglot.dataframe import types -from sqlglot.dataframe import functions as F - -ctx = snowflake.connector.connect( - user=os.environ["SNOWFLAKE_USER"], - password=os.environ["SNOWFLAKE_PASS"], - account=os.environ["SNOWFLAKE_ACCOUNT"] -) -cs = ctx.cursor() - -data = [ - (1, "Jack", "Shephard", 34), - (2, "John", "Locke", 48), - (3, "Kate", "Austen", 34), - (4, "Claire", "Littleton", 22), - (5, "Hugo", "Reyes", 26), -] -schema = types.StructType([ - types.StructField('employee_id', types.IntegerType(), False), - 
types.StructField('fname', types.StringType(), False), - types.StructField('lname', types.StringType(), False), - types.StructField('age', types.IntegerType(), False), -]) - -sql_statements = ( - SparkSession - .builder - .config("sqlframe.dialect", "snowflake") - .getOrCreate() - .createDataFrame(data, schema) - .groupBy(F.col("age")) - .agg(F.countDistinct(F.col("lname")).alias("num_employees")) - .sql() -) - -try: - for sql in sql_statements: - cs.execute(sql) - results = cs.fetchall() - for row in results: - print(f"Age: {row[0]}, Num Employees: {row[1]}") -finally: - cs.close() -ctx.close() -``` - -### Spark - -```python -from pyspark.sql.session import SparkSession as PySparkSession -from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.dataframe.sql import types -from sqlglot.dataframe.sql import functions as F - -data = [ - (1, "Jack", "Shephard", 34), - (2, "John", "Locke", 48), - (3, "Kate", "Austen", 34), - (4, "Claire", "Littleton", 22), - (5, "Hugo", "Reyes", 26), -] -schema = types.StructType([ - types.StructField('employee_id', types.IntegerType(), False), - types.StructField('fname', types.StringType(), False), - types.StructField('lname', types.StringType(), False), - types.StructField('age', types.IntegerType(), False), -]) - -sql_statements = ( - SparkSession() - .createDataFrame(data, schema) - .groupBy(F.col("age")) - .agg(F.countDistinct(F.col("employee_id")).alias("num_employees")) - .sql() -) - -pyspark = PySparkSession.builder.master("local[*]").getOrCreate() - -df = None -for sql in sql_statements: - df = pyspark.sql(sql) - -assert df is not None -df.show() -``` - -# Unsupportable Operations - -Any operation that lacks a way to represent it in SQL cannot be supported by this tool. An example of this would be rdd operations. Since the DataFrame API though is mostly modeled around SQL concepts most operations can be supported. diff --git a/altimate_packages/sqlglot/dataframe/__init__.py b/altimate_packages/sqlglot/dataframe/__init__.py deleted file mode 100644 index a57e99013..000000000 --- a/altimate_packages/sqlglot/dataframe/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -.. 
include:: ./README.md -""" diff --git a/altimate_packages/sqlglot/dataframe/sql/__init__.py b/altimate_packages/sqlglot/dataframe/sql/__init__.py deleted file mode 100644 index 3f9080277..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -from sqlglot.dataframe.sql.column import Column -from sqlglot.dataframe.sql.dataframe import DataFrame, DataFrameNaFunctions -from sqlglot.dataframe.sql.group import GroupedData -from sqlglot.dataframe.sql.readwriter import DataFrameReader, DataFrameWriter -from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.dataframe.sql.window import Window, WindowSpec - -__all__ = [ - "SparkSession", - "DataFrame", - "GroupedData", - "Column", - "DataFrameNaFunctions", - "Window", - "WindowSpec", - "DataFrameReader", - "DataFrameWriter", -] diff --git a/altimate_packages/sqlglot/dataframe/sql/_typing.py b/altimate_packages/sqlglot/dataframe/sql/_typing.py deleted file mode 100644 index fb46026fd..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/_typing.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -import datetime -import typing as t - -from sqlglot import expressions as exp - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.column import Column - from sqlglot.dataframe.sql.types import StructType - -ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] -ColumnOrName = t.Union[Column, str] -ColumnOrLiteral = t.Union[ - Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime -] -SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]] -OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert] diff --git a/altimate_packages/sqlglot/dataframe/sql/column.py b/altimate_packages/sqlglot/dataframe/sql/column.py deleted file mode 100644 index b1569d34c..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/column.py +++ /dev/null @@ -1,332 +0,0 @@ -from __future__ import annotations - -import typing as t - -import sqlglot as sqlglot -from sqlglot import expressions as exp -from sqlglot.dataframe.sql.types import DataType -from sqlglot.helper import flatten, is_iterable - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnOrLiteral - from sqlglot.dataframe.sql.window import WindowSpec - - -class Column: - def __init__(self, expression: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]): - from sqlglot.dataframe.sql.session import SparkSession - - if isinstance(expression, Column): - expression = expression.expression # type: ignore - elif expression is None or not isinstance(expression, (str, exp.Expression)): - expression = self._lit(expression).expression # type: ignore - elif not isinstance(expression, exp.Column): - expression = sqlglot.maybe_parse(expression, dialect=SparkSession().dialect).transform( - SparkSession().dialect.normalize_identifier, copy=False - ) - if expression is None: - raise ValueError(f"Could not parse {expression}") - - self.expression: exp.Expression = expression # type: ignore - - def __repr__(self): - return repr(self.expression) - - def __hash__(self): - return hash(self.expression) - - def __eq__(self, other: ColumnOrLiteral) -> Column: # type: ignore - return self.binary_op(exp.EQ, other) - - def __ne__(self, other: ColumnOrLiteral) -> Column: # type: ignore - return self.binary_op(exp.NEQ, other) - - def __gt__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.GT, other) - - def __ge__(self, 
other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.GTE, other) - - def __lt__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.LT, other) - - def __le__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.LTE, other) - - def __and__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.And, other) - - def __or__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Or, other) - - def __mod__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Mod, other) - - def __add__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Add, other) - - def __sub__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Sub, other) - - def __mul__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Mul, other) - - def __truediv__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) - - def __div__(self, other: ColumnOrLiteral) -> Column: - return self.binary_op(exp.Div, other) - - def __neg__(self) -> Column: - return self.unary_op(exp.Neg) - - def __radd__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Add, other) - - def __rsub__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Sub, other) - - def __rmul__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Mul, other) - - def __rdiv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) - - def __rtruediv__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Div, other) - - def __rmod__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Mod, other) - - def __pow__(self, power: ColumnOrLiteral, modulo=None): - return Column(exp.Pow(this=self.expression, expression=Column(power).expression)) - - def __rpow__(self, power: ColumnOrLiteral): - return Column(exp.Pow(this=Column(power).expression, expression=self.expression)) - - def __invert__(self): - return self.unary_op(exp.Not) - - def __rand__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.And, other) - - def __ror__(self, other: ColumnOrLiteral) -> Column: - return self.inverse_binary_op(exp.Or, other) - - @classmethod - def ensure_col(cls, value: t.Optional[t.Union[ColumnOrLiteral, exp.Expression]]) -> Column: - return cls(value) - - @classmethod - def ensure_cols(cls, args: t.List[t.Union[ColumnOrLiteral, exp.Expression]]) -> t.List[Column]: - return [cls.ensure_col(x) if not isinstance(x, Column) else x for x in args] - - @classmethod - def _lit(cls, value: ColumnOrLiteral) -> Column: - if isinstance(value, dict): - columns = [cls._lit(v).alias(k).expression for k, v in value.items()] - return cls(exp.Struct(expressions=columns)) - return cls(exp.convert(value)) - - @classmethod - def invoke_anonymous_function( - cls, column: t.Optional[ColumnOrLiteral], func_name: str, *args: t.Optional[ColumnOrLiteral] - ) -> Column: - columns = [] if column is None else [cls.ensure_col(column)] - column_args = [cls.ensure_col(arg) for arg in args] - expressions = [x.expression for x in columns + column_args] - new_expression = exp.Anonymous(this=func_name.upper(), expressions=expressions) - return Column(new_expression) - - @classmethod - def invoke_expression_over_column( - cls, column: t.Optional[ColumnOrLiteral], callable_expression: t.Callable, **kwargs - ) -> Column: - ensured_column = None if column is None else cls.ensure_col(column) - 
ensure_expression_values = { - k: [Column.ensure_col(x).expression for x in v] - if is_iterable(v) - else Column.ensure_col(v).expression - for k, v in kwargs.items() - if v is not None - } - new_expression = ( - callable_expression(**ensure_expression_values) - if ensured_column is None - else callable_expression( - this=ensured_column.column_expression, **ensure_expression_values - ) - ) - return Column(new_expression) - - def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column( - klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs) - ) - - def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column: - return Column( - klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs) - ) - - def unary_op(self, klass: t.Callable, **kwargs) -> Column: - return Column(klass(this=self.column_expression, **kwargs)) - - @property - def is_alias(self): - return isinstance(self.expression, exp.Alias) - - @property - def is_column(self): - return isinstance(self.expression, exp.Column) - - @property - def column_expression(self) -> t.Union[exp.Column, exp.Literal]: - return self.expression.unalias() - - @property - def alias_or_name(self) -> str: - return self.expression.alias_or_name - - @classmethod - def ensure_literal(cls, value) -> Column: - from sqlglot.dataframe.sql.functions import lit - - if isinstance(value, cls): - value = value.expression - if not isinstance(value, exp.Literal): - return lit(value) - return Column(value) - - def copy(self) -> Column: - return Column(self.expression.copy()) - - def set_table_name(self, table_name: str, copy=False) -> Column: - expression = self.expression.copy() if copy else self.expression - expression.set("table", exp.to_identifier(table_name)) - return Column(expression) - - def sql(self, **kwargs) -> str: - from sqlglot.dataframe.sql.session import SparkSession - - return self.expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) - - def alias(self, name: str) -> Column: - new_expression = exp.alias_(self.column_expression, name) - return Column(new_expression) - - def asc(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=True) - return Column(new_expression) - - def desc(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=False) - return Column(new_expression) - - asc_nulls_first = asc - - def asc_nulls_last(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=False, nulls_first=False) - return Column(new_expression) - - def desc_nulls_first(self) -> Column: - new_expression = exp.Ordered(this=self.column_expression, desc=True, nulls_first=True) - return Column(new_expression) - - desc_nulls_last = desc - - def when(self, condition: Column, value: t.Any) -> Column: - from sqlglot.dataframe.sql.functions import when - - column_with_if = when(condition, value) - if not isinstance(self.expression, exp.Case): - return column_with_if - new_column = self.copy() - new_column.expression.args["ifs"].extend(column_with_if.expression.args["ifs"]) - return new_column - - def otherwise(self, value: t.Any) -> Column: - from sqlglot.dataframe.sql.functions import lit - - true_value = value if isinstance(value, Column) else lit(value) - new_column = self.copy() - new_column.expression.set("default", true_value.column_expression) - return new_column - - def isNull(self) -> Column: - new_expression 
= exp.Is(this=self.column_expression, expression=exp.Null()) - return Column(new_expression) - - def isNotNull(self) -> Column: - new_expression = exp.Not(this=exp.Is(this=self.column_expression, expression=exp.Null())) - return Column(new_expression) - - def cast(self, dataType: t.Union[str, DataType]) -> Column: - """ - Functionality Difference: PySpark cast accepts a datatype instance of the datatype class - Sqlglot doesn't currently replicate this class so it only accepts a string - """ - from sqlglot.dataframe.sql.session import SparkSession - - if isinstance(dataType, DataType): - dataType = dataType.simpleString() - return Column(exp.cast(self.column_expression, dataType, dialect=SparkSession().dialect)) - - def startswith(self, value: t.Union[str, Column]) -> Column: - value = self._lit(value) if not isinstance(value, Column) else value - return self.invoke_anonymous_function(self, "STARTSWITH", value) - - def endswith(self, value: t.Union[str, Column]) -> Column: - value = self._lit(value) if not isinstance(value, Column) else value - return self.invoke_anonymous_function(self, "ENDSWITH", value) - - def rlike(self, regexp: str) -> Column: - return self.invoke_expression_over_column( - column=self, callable_expression=exp.RegexpLike, expression=self._lit(regexp).expression - ) - - def like(self, other: str): - return self.invoke_expression_over_column( - self, exp.Like, expression=self._lit(other).expression - ) - - def ilike(self, other: str): - return self.invoke_expression_over_column( - self, exp.ILike, expression=self._lit(other).expression - ) - - def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column: - startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos - length = self._lit(length) if not isinstance(length, Column) else length - return Column.invoke_expression_over_column( - self, exp.Substring, start=startPos.expression, length=length.expression - ) - - def isin(self, *cols: t.Union[ColumnOrLiteral, t.Iterable[ColumnOrLiteral]]): - columns = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore - expressions = [self._lit(x).expression for x in columns] - return Column.invoke_expression_over_column(self, exp.In, expressions=expressions) # type: ignore - - def between( - self, - lowerBound: t.Union[ColumnOrLiteral], - upperBound: t.Union[ColumnOrLiteral], - ) -> Column: - lower_bound_exp = ( - self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound - ) - upper_bound_exp = ( - self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound - ) - return Column( - exp.Between( - this=self.column_expression, - low=lower_bound_exp.expression, - high=upper_bound_exp.expression, - ) - ) - - def over(self, window: WindowSpec) -> Column: - window_expression = window.expression.copy() - window_expression.set("this", self.column_expression) - return Column(window_expression) diff --git a/altimate_packages/sqlglot/dataframe/sql/dataframe.py b/altimate_packages/sqlglot/dataframe/sql/dataframe.py deleted file mode 100644 index 450bb9933..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/dataframe.py +++ /dev/null @@ -1,866 +0,0 @@ -from __future__ import annotations - -import functools -import logging -import typing as t -import zlib -from copy import copy - -import sqlglot as sqlglot -from sqlglot import Dialect, expressions as exp -from sqlglot.dataframe.sql import functions as F -from sqlglot.dataframe.sql.column import Column -from 
sqlglot.dataframe.sql.group import GroupedData -from sqlglot.dataframe.sql.normalize import normalize -from sqlglot.dataframe.sql.operations import Operation, operation -from sqlglot.dataframe.sql.readwriter import DataFrameWriter -from sqlglot.dataframe.sql.transforms import replace_id_value -from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join -from sqlglot.dataframe.sql.window import Window -from sqlglot.helper import ensure_list, object_to_dict, seq_get -from sqlglot.optimizer import optimize as optimize_func -from sqlglot.optimizer.qualify_columns import quote_identifiers - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ( - ColumnLiterals, - ColumnOrLiteral, - ColumnOrName, - OutputExpressionContainer, - ) - from sqlglot.dataframe.sql.session import SparkSession - from sqlglot.dialects.dialect import DialectType - -logger = logging.getLogger("sqlglot") - -JOIN_HINTS = { - "BROADCAST", - "BROADCASTJOIN", - "MAPJOIN", - "MERGE", - "SHUFFLEMERGE", - "MERGEJOIN", - "SHUFFLE_HASH", - "SHUFFLE_REPLICATE_NL", -} - - -class DataFrame: - def __init__( - self, - spark: SparkSession, - expression: exp.Select, - branch_id: t.Optional[str] = None, - sequence_id: t.Optional[str] = None, - last_op: Operation = Operation.INIT, - pending_hints: t.Optional[t.List[exp.Expression]] = None, - output_expression_container: t.Optional[OutputExpressionContainer] = None, - **kwargs, - ): - self.spark = spark - self.expression = expression - self.branch_id = branch_id or self.spark._random_branch_id - self.sequence_id = sequence_id or self.spark._random_sequence_id - self.last_op = last_op - self.pending_hints = pending_hints or [] - self.output_expression_container = output_expression_container or exp.Select() - - def __getattr__(self, column_name: str) -> Column: - return self[column_name] - - def __getitem__(self, column_name: str) -> Column: - column_name = f"{self.branch_id}.{column_name}" - return Column(column_name) - - def __copy__(self): - return self.copy() - - @property - def sparkSession(self): - return self.spark - - @property - def write(self): - return DataFrameWriter(self) - - @property - def latest_cte_name(self) -> str: - if not self.expression.ctes: - from_exp = self.expression.args["from"] - if from_exp.alias_or_name: - return from_exp.alias_or_name - table_alias = from_exp.find(exp.TableAlias) - if not table_alias: - raise RuntimeError( - f"Could not find an alias name for this expression: {self.expression}" - ) - return table_alias.alias_or_name - return self.expression.ctes[-1].alias - - @property - def pending_join_hints(self): - return [hint for hint in self.pending_hints if isinstance(hint, exp.JoinHint)] - - @property - def pending_partition_hints(self): - return [hint for hint in self.pending_hints if isinstance(hint, exp.Anonymous)] - - @property - def columns(self) -> t.List[str]: - return self.expression.named_selects - - @property - def na(self) -> DataFrameNaFunctions: - return DataFrameNaFunctions(self) - - def _replace_cte_names_with_hashes(self, expression: exp.Select): - replacement_mapping = {} - for cte in expression.ctes: - old_name_id = cte.args["alias"].this - new_hashed_id = exp.to_identifier( - self._create_hash_from_expression(cte.this), quoted=old_name_id.args["quoted"] - ) - replacement_mapping[old_name_id] = new_hashed_id - expression = expression.transform(replace_id_value, replacement_mapping) - return expression - - def _create_cte_from_expression( - self, - expression: exp.Expression, - branch_id: t.Optional[str] = 
None, - sequence_id: t.Optional[str] = None, - **kwargs, - ) -> t.Tuple[exp.CTE, str]: - name = self._create_hash_from_expression(expression) - expression_to_cte = expression.copy() - expression_to_cte.set("with", None) - cte = exp.Select().with_(name, as_=expression_to_cte, **kwargs).ctes[0] - cte.set("branch_id", branch_id or self.branch_id) - cte.set("sequence_id", sequence_id or self.sequence_id) - return cte, name - - @t.overload - def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]: - ... - - @t.overload - def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]: - ... - - def _ensure_list_of_columns(self, cols): - return Column.ensure_cols(ensure_list(cols)) - - def _ensure_and_normalize_cols(self, cols, expression: t.Optional[exp.Select] = None): - cols = self._ensure_list_of_columns(cols) - normalize(self.spark, expression or self.expression, cols) - return cols - - def _ensure_and_normalize_col(self, col): - col = Column.ensure_col(col) - normalize(self.spark, self.expression, col) - return col - - def _convert_leaf_to_cte(self, sequence_id: t.Optional[str] = None) -> DataFrame: - df = self._resolve_pending_hints() - sequence_id = sequence_id or df.sequence_id - expression = df.expression.copy() - cte_expression, cte_name = df._create_cte_from_expression( - expression=expression, sequence_id=sequence_id - ) - new_expression = df._add_ctes_to_expression( - exp.Select(), expression.ctes + [cte_expression] - ) - sel_columns = df._get_outer_select_columns(cte_expression) - new_expression = new_expression.from_(cte_name).select( - *[x.alias_or_name for x in sel_columns] - ) - return df.copy(expression=new_expression, sequence_id=sequence_id) - - def _resolve_pending_hints(self) -> DataFrame: - df = self.copy() - if not self.pending_hints: - return df - expression = df.expression - hint_expression = expression.args.get("hint") or exp.Hint(expressions=[]) - for hint in df.pending_partition_hints: - hint_expression.append("expressions", hint) - df.pending_hints.remove(hint) - - join_aliases = { - join_table.alias_or_name - for join_table in get_tables_from_expression_with_join(expression) - } - if join_aliases: - for hint in df.pending_join_hints: - for sequence_id_expression in hint.expressions: - sequence_id_or_name = sequence_id_expression.alias_or_name - sequence_ids_to_match = [sequence_id_or_name] - if sequence_id_or_name in df.spark.name_to_sequence_id_mapping: - sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[ - sequence_id_or_name - ] - matching_ctes = [ - cte - for cte in reversed(expression.ctes) - if cte.args["sequence_id"] in sequence_ids_to_match - ] - for matching_cte in matching_ctes: - if matching_cte.alias_or_name in join_aliases: - sequence_id_expression.set("this", matching_cte.args["alias"].this) - df.pending_hints.remove(hint) - break - hint_expression.append("expressions", hint) - if hint_expression.expressions: - expression.set("hint", hint_expression) - return df - - def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame: - hint_name = hint_name.upper() - hint_expression = ( - exp.JoinHint( - this=hint_name, - expressions=[exp.to_table(parameter.alias_or_name) for parameter in args], - ) - if hint_name in JOIN_HINTS - else exp.Anonymous( - this=hint_name, expressions=[parameter.expression for parameter in args] - ) - ) - new_df = self.copy() - new_df.pending_hints.append(hint_expression) - return new_df - - def _set_operation(self, klass: t.Callable, other: DataFrame, distinct: bool): 
- other_df = other._convert_leaf_to_cte() - base_expression = self.expression.copy() - base_expression = self._add_ctes_to_expression(base_expression, other_df.expression.ctes) - all_ctes = base_expression.ctes - other_df.expression.set("with", None) - base_expression.set("with", None) - operation = klass(this=base_expression, distinct=distinct, expression=other_df.expression) - operation.set("with", exp.With(expressions=all_ctes)) - return self.copy(expression=operation)._convert_leaf_to_cte() - - def _cache(self, storage_level: str): - df = self._convert_leaf_to_cte() - df.expression.ctes[-1].set("cache_storage_level", storage_level) - return df - - @classmethod - def _add_ctes_to_expression(cls, expression: exp.Select, ctes: t.List[exp.CTE]) -> exp.Select: - expression = expression.copy() - with_expression = expression.args.get("with") - if with_expression: - existing_ctes = with_expression.expressions - existsing_cte_names = {x.alias_or_name for x in existing_ctes} - for cte in ctes: - if cte.alias_or_name not in existsing_cte_names: - existing_ctes.append(cte) - else: - existing_ctes = ctes - expression.set("with", exp.With(expressions=existing_ctes)) - return expression - - @classmethod - def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]: - expression = item.expression if isinstance(item, DataFrame) else item - return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions] - - @classmethod - def _create_hash_from_expression(cls, expression: exp.Expression) -> str: - from sqlglot.dataframe.sql.session import SparkSession - - value = expression.sql(dialect=SparkSession().dialect).encode("utf-8") - return f"t{zlib.crc32(value)}"[:6] - - def _get_select_expressions( - self, - ) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]: - select_expressions: t.List[ - t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select] - ] = [] - main_select_ctes: t.List[exp.CTE] = [] - for cte in self.expression.ctes: - cache_storage_level = cte.args.get("cache_storage_level") - if cache_storage_level: - select_expression = cte.this.copy() - select_expression.set("with", exp.With(expressions=copy(main_select_ctes))) - select_expression.set("cte_alias_name", cte.alias_or_name) - select_expression.set("cache_storage_level", cache_storage_level) - select_expressions.append((exp.Cache, select_expression)) - else: - main_select_ctes.append(cte) - main_select = self.expression.copy() - if main_select_ctes: - main_select.set("with", exp.With(expressions=main_select_ctes)) - expression_select_pair = (type(self.output_expression_container), main_select) - select_expressions.append(expression_select_pair) # type: ignore - return select_expressions - - def sql( - self, dialect: t.Optional[DialectType] = None, optimize: bool = True, **kwargs - ) -> t.List[str]: - from sqlglot.dataframe.sql.session import SparkSession - - if dialect and Dialect.get_or_raise(dialect)() != SparkSession().dialect: - logger.warning( - f"The recommended way of defining a dialect is by doing `SparkSession.builder.config('sqlframe.dialect', '{dialect}').getOrCreate()`. It is no longer needed then when calling `sql`. If you run into issues try updating your query to use this pattern." 
- ) - df = self._resolve_pending_hints() - select_expressions = df._get_select_expressions() - output_expressions: t.List[t.Union[exp.Select, exp.Cache, exp.Drop]] = [] - replacement_mapping: t.Dict[exp.Identifier, exp.Identifier] = {} - for expression_type, select_expression in select_expressions: - select_expression = select_expression.transform(replace_id_value, replacement_mapping) - if optimize: - quote_identifiers(select_expression) - select_expression = t.cast( - exp.Select, optimize_func(select_expression, dialect=SparkSession().dialect) - ) - select_expression = df._replace_cte_names_with_hashes(select_expression) - expression: t.Union[exp.Select, exp.Cache, exp.Drop] - if expression_type == exp.Cache: - cache_table_name = df._create_hash_from_expression(select_expression) - cache_table = exp.to_table(cache_table_name) - original_alias_name = select_expression.args["cte_alias_name"] - - replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore - cache_table_name - ) - sqlglot.schema.add_table( - cache_table_name, - { - expression.alias_or_name: expression.type.sql( - dialect=SparkSession().dialect - ) - for expression in select_expression.expressions - }, - dialect=SparkSession().dialect, - ) - cache_storage_level = select_expression.args["cache_storage_level"] - options = [ - exp.Literal.string("storageLevel"), - exp.Literal.string(cache_storage_level), - ] - expression = exp.Cache( - this=cache_table, expression=select_expression, lazy=True, options=options - ) - # We will drop the "view" if it exists before running the cache table - output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW")) - elif expression_type == exp.Create: - expression = df.output_expression_container.copy() - expression.set("expression", select_expression) - elif expression_type == exp.Insert: - expression = df.output_expression_container.copy() - select_without_ctes = select_expression.copy() - select_without_ctes.set("with", None) - expression.set("expression", select_without_ctes) - if select_expression.ctes: - expression.set("with", exp.With(expressions=select_expression.ctes)) - elif expression_type == exp.Select: - expression = select_expression - else: - raise ValueError(f"Invalid expression type: {expression_type}") - output_expressions.append(expression) - - return [ - expression.sql(**{"dialect": SparkSession().dialect, **kwargs}) - for expression in output_expressions - ] - - def copy(self, **kwargs) -> DataFrame: - return DataFrame(**object_to_dict(self, **kwargs)) - - @operation(Operation.SELECT) - def select(self, *cols, **kwargs) -> DataFrame: - cols = self._ensure_and_normalize_cols(cols) - kwargs["append"] = kwargs.get("append", False) - if self.expression.args.get("joins"): - ambiguous_cols = [ - col - for col in cols - if isinstance(col.column_expression, exp.Column) and not col.column_expression.table - ] - if ambiguous_cols: - join_table_identifiers = [ - x.this for x in get_tables_from_expression_with_join(self.expression) - ] - cte_names_in_join = [x.this for x in join_table_identifiers] - # If we have columns that resolve to multiple CTE expressions then we want to use each CTE left-to-right - # and therefore we allow multiple columns with the same name in the result. This matches the behavior - # of Spark. 
- resolved_column_position: t.Dict[Column, int] = {col: -1 for col in ambiguous_cols} - for ambiguous_col in ambiguous_cols: - ctes_with_column = [ - cte - for cte in self.expression.ctes - if cte.alias_or_name in cte_names_in_join - and ambiguous_col.alias_or_name in cte.this.named_selects - ] - # Check if there is a CTE with this column that we haven't used before. If so, use it. Otherwise, - # use the same CTE we used before - cte = seq_get(ctes_with_column, resolved_column_position[ambiguous_col] + 1) - if cte: - resolved_column_position[ambiguous_col] += 1 - else: - cte = ctes_with_column[resolved_column_position[ambiguous_col]] - ambiguous_col.expression.set("table", cte.alias_or_name) - return self.copy( - expression=self.expression.select(*[x.expression for x in cols], **kwargs), **kwargs - ) - - @operation(Operation.NO_OP) - def alias(self, name: str, **kwargs) -> DataFrame: - new_sequence_id = self.spark._random_sequence_id - df = self.copy() - for join_hint in df.pending_join_hints: - for expression in join_hint.expressions: - if expression.alias_or_name == self.sequence_id: - expression.set("this", Column.ensure_col(new_sequence_id).expression) - df.spark._add_alias_to_mapping(name, new_sequence_id) - return df._convert_leaf_to_cte(sequence_id=new_sequence_id) - - @operation(Operation.WHERE) - def where(self, column: t.Union[Column, bool], **kwargs) -> DataFrame: - col = self._ensure_and_normalize_col(column) - return self.copy(expression=self.expression.where(col.expression)) - - filter = where - - @operation(Operation.GROUP_BY) - def groupBy(self, *cols, **kwargs) -> GroupedData: - columns = self._ensure_and_normalize_cols(cols) - return GroupedData(self, columns, self.last_op) - - @operation(Operation.SELECT) - def agg(self, *exprs, **kwargs) -> DataFrame: - cols = self._ensure_and_normalize_cols(exprs) - return self.groupBy().agg(*cols) - - @operation(Operation.FROM) - def join( - self, - other_df: DataFrame, - on: t.Union[str, t.List[str], Column, t.List[Column]], - how: str = "inner", - **kwargs, - ) -> DataFrame: - other_df = other_df._convert_leaf_to_cte() - join_columns = self._ensure_list_of_columns(on) - # We will determine actual "join on" expression later so we don't provide it at first - join_expression = self.expression.join( - other_df.latest_cte_name, join_type=how.replace("_", " ") - ) - join_expression = self._add_ctes_to_expression(join_expression, other_df.expression.ctes) - self_columns = self._get_outer_select_columns(join_expression) - other_columns = self._get_outer_select_columns(other_df) - # Determines the join clause and select columns to be used passed on what type of columns were provided for - # the join. The columns returned changes based on how the on expression is provided. - if isinstance(join_columns[0].expression, exp.Column): - """ - Unique characteristics of join on column names only: - * The column names are put at the front of the select list - * The column names are deduplicated across the entire select list and only the column names (other dups are allowed) - """ - table_names = [ - table.alias_or_name - for table in get_tables_from_expression_with_join(join_expression) - ] - potential_ctes = [ - cte - for cte in join_expression.ctes - if cte.alias_or_name in table_names - and cte.alias_or_name != other_df.latest_cte_name - ] - # Determine the table to reference for the left side of the join by checking each of the left side - # tables and see if they have the column being referenced. 
- join_column_pairs = [] - for join_column in join_columns: - num_matching_ctes = 0 - for cte in potential_ctes: - if join_column.alias_or_name in cte.this.named_selects: - left_column = join_column.copy().set_table_name(cte.alias_or_name) - right_column = join_column.copy().set_table_name(other_df.latest_cte_name) - join_column_pairs.append((left_column, right_column)) - num_matching_ctes += 1 - if num_matching_ctes > 1: - raise ValueError( - f"Column {join_column.alias_or_name} is ambiguous. Please specify the table name." - ) - elif num_matching_ctes == 0: - raise ValueError( - f"Column {join_column.alias_or_name} does not exist in any of the tables." - ) - join_clause = functools.reduce( - lambda x, y: x & y, - [left_column == right_column for left_column, right_column in join_column_pairs], - ) - join_column_names = [left_col.alias_or_name for left_col, _ in join_column_pairs] - # To match spark behavior only the join clause gets deduplicated and it gets put in the front of the column list - select_column_names = [ - column.alias_or_name - if not isinstance(column.expression.this, exp.Star) - else column.sql() - for column in self_columns + other_columns - ] - select_column_names = [ - column_name - for column_name in select_column_names - if column_name not in join_column_names - ] - select_column_names = join_column_names + select_column_names - else: - """ - Unique characteristics of join on expressions: - * There is no deduplication of the results. - * The left join dataframe columns go first and right come after. No sort preference is given to join columns - """ - join_columns = self._ensure_and_normalize_cols(join_columns, join_expression) - if len(join_columns) > 1: - join_columns = [functools.reduce(lambda x, y: x & y, join_columns)] - join_clause = join_columns[0] - select_column_names = [column.alias_or_name for column in self_columns + other_columns] - - # Update the on expression with the actual join clause to replace the dummy one from before - join_expression.args["joins"][-1].set("on", join_clause.expression) - new_df = self.copy(expression=join_expression) - new_df.pending_join_hints.extend(self.pending_join_hints) - new_df.pending_hints.extend(other_df.pending_hints) - new_df = new_df.select.__wrapped__(new_df, *select_column_names) - return new_df - - @operation(Operation.ORDER_BY) - def orderBy( - self, - *cols: t.Union[str, Column], - ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None, - ) -> DataFrame: - """ - This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark - has irregular behavior and can result in runtime errors. Users shouldn't be mixing the two anyways so this - is unlikely to come up. 
- """ - columns = self._ensure_and_normalize_cols(cols) - pre_ordered_col_indexes = [ - x - for x in [ - i if isinstance(col.expression, exp.Ordered) else None - for i, col in enumerate(columns) - ] - if x is not None - ] - if ascending is None: - ascending = [True] * len(columns) - elif not isinstance(ascending, list): - ascending = [ascending] * len(columns) - ascending = [bool(x) for i, x in enumerate(ascending)] - assert len(columns) == len( - ascending - ), "The length of items in ascending must equal the number of columns provided" - col_and_ascending = list(zip(columns, ascending)) - order_by_columns = [ - exp.Ordered(this=col.expression, desc=not asc) - if i not in pre_ordered_col_indexes - else columns[i].column_expression - for i, (col, asc) in enumerate(col_and_ascending) - ] - return self.copy(expression=self.expression.order_by(*order_by_columns)) - - sort = orderBy - - @operation(Operation.FROM) - def union(self, other: DataFrame) -> DataFrame: - return self._set_operation(exp.Union, other, False) - - unionAll = union - - @operation(Operation.FROM) - def unionByName(self, other: DataFrame, allowMissingColumns: bool = False): - l_columns = self.columns - r_columns = other.columns - if not allowMissingColumns: - l_expressions = l_columns - r_expressions = l_columns - else: - l_expressions = [] - r_expressions = [] - r_columns_unused = copy(r_columns) - for l_column in l_columns: - l_expressions.append(l_column) - if l_column in r_columns: - r_expressions.append(l_column) - r_columns_unused.remove(l_column) - else: - r_expressions.append(exp.alias_(exp.Null(), l_column, copy=False)) - for r_column in r_columns_unused: - l_expressions.append(exp.alias_(exp.Null(), r_column, copy=False)) - r_expressions.append(r_column) - r_df = ( - other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions)) - ) - l_df = self.copy() - if allowMissingColumns: - l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions)) - return l_df._set_operation(exp.Union, r_df, False) - - @operation(Operation.FROM) - def intersect(self, other: DataFrame) -> DataFrame: - return self._set_operation(exp.Intersect, other, True) - - @operation(Operation.FROM) - def intersectAll(self, other: DataFrame) -> DataFrame: - return self._set_operation(exp.Intersect, other, False) - - @operation(Operation.FROM) - def exceptAll(self, other: DataFrame) -> DataFrame: - return self._set_operation(exp.Except, other, False) - - @operation(Operation.SELECT) - def distinct(self) -> DataFrame: - return self.copy(expression=self.expression.distinct()) - - @operation(Operation.SELECT) - def dropDuplicates(self, subset: t.Optional[t.List[str]] = None): - if not subset: - return self.distinct() - column_names = ensure_list(subset) - window = Window.partitionBy(*column_names).orderBy(*column_names) - return ( - self.copy() - .withColumn("row_num", F.row_number().over(window)) - .where(F.col("row_num") == F.lit(1)) - .drop("row_num") - ) - - @operation(Operation.FROM) - def dropna( - self, - how: str = "any", - thresh: t.Optional[int] = None, - subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, - ) -> DataFrame: - minimum_non_null = thresh or 0 # will be determined later if thresh is null - new_df = self.copy() - all_columns = self._get_outer_select_columns(new_df.expression) - if subset: - null_check_columns = self._ensure_and_normalize_cols(subset) - else: - null_check_columns = all_columns - if thresh is None: - minimum_num_nulls = 1 if how == "any" else 
len(null_check_columns) - else: - minimum_num_nulls = len(null_check_columns) - minimum_non_null + 1 - if minimum_num_nulls > len(null_check_columns): - raise RuntimeError( - f"The minimum num nulls for dropna must be less than or equal to the number of columns. " - f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}" - ) - if_null_checks = [ - F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns - ] - nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks) - num_nulls = nulls_added_together.alias("num_nulls") - new_df = new_df.select(num_nulls, append=True) - filtered_df = new_df.where(F.col("num_nulls") < F.lit(minimum_num_nulls)) - final_df = filtered_df.select(*all_columns) - return final_df - - @operation(Operation.FROM) - def fillna( - self, - value: t.Union[ColumnLiterals], - subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, - ) -> DataFrame: - """ - Functionality Difference: If you provide a value to replace a null and that type conflicts - with the type of the column then PySpark will just ignore your replacement. - This will try to cast them to be the same in some cases. So they won't always match. - Best to not mix types so make sure replacement is the same type as the column - - Possibility for improvement: Use `typeof` function to get the type of the column - and check if it matches the type of the value provided. If not then make it null. - """ - from sqlglot.dataframe.sql.functions import lit - - values = None - columns = None - new_df = self.copy() - all_columns = self._get_outer_select_columns(new_df.expression) - all_column_mapping = {column.alias_or_name: column for column in all_columns} - if isinstance(value, dict): - values = list(value.values()) - columns = self._ensure_and_normalize_cols(list(value)) - if not columns: - columns = self._ensure_and_normalize_cols(subset) if subset else all_columns - if not values: - values = [value] * len(columns) - value_columns = [lit(value) for value in values] - - null_replacement_mapping = { - column.alias_or_name: ( - F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name) - ) - for column, value in zip(columns, value_columns) - } - null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping} - null_replacement_columns = [ - null_replacement_mapping[column.alias_or_name] for column in all_columns - ] - new_df = new_df.select(*null_replacement_columns) - return new_df - - @operation(Operation.FROM) - def replace( - self, - to_replace: t.Union[bool, int, float, str, t.List, t.Dict], - value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, - subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None, - ) -> DataFrame: - from sqlglot.dataframe.sql.functions import lit - - old_values = None - new_df = self.copy() - all_columns = self._get_outer_select_columns(new_df.expression) - all_column_mapping = {column.alias_or_name: column for column in all_columns} - - columns = self._ensure_and_normalize_cols(subset) if subset else all_columns - if isinstance(to_replace, dict): - old_values = list(to_replace) - new_values = list(to_replace.values()) - elif not old_values and isinstance(to_replace, list): - assert isinstance(value, list), "value must be a list since the replacements are a list" - assert len(to_replace) == len( - value - ), "the replacements and values must be the same length" - old_values = to_replace - new_values = value - else: - old_values = [to_replace] * 
len(columns) - new_values = [value] * len(columns) - old_values = [lit(value) for value in old_values] - new_values = [lit(value) for value in new_values] - - replacement_mapping = {} - for column in columns: - expression = Column(None) - for i, (old_value, new_value) in enumerate(zip(old_values, new_values)): - if i == 0: - expression = F.when(column == old_value, new_value) - else: - expression = expression.when(column == old_value, new_value) # type: ignore - replacement_mapping[column.alias_or_name] = expression.otherwise(column).alias( - column.expression.alias_or_name - ) - - replacement_mapping = {**all_column_mapping, **replacement_mapping} - replacement_columns = [replacement_mapping[column.alias_or_name] for column in all_columns] - new_df = new_df.select(*replacement_columns) - return new_df - - @operation(Operation.SELECT) - def withColumn(self, colName: str, col: Column) -> DataFrame: - col = self._ensure_and_normalize_col(col) - existing_col_names = self.expression.named_selects - existing_col_index = ( - existing_col_names.index(colName) if colName in existing_col_names else None - ) - if existing_col_index: - expression = self.expression.copy() - expression.expressions[existing_col_index] = col.expression - return self.copy(expression=expression) - return self.copy().select(col.alias(colName), append=True) - - @operation(Operation.SELECT) - def withColumnRenamed(self, existing: str, new: str): - expression = self.expression.copy() - existing_columns = [ - expression - for expression in expression.expressions - if expression.alias_or_name == existing - ] - if not existing_columns: - raise ValueError("Tried to rename a column that doesn't exist") - for existing_column in existing_columns: - if isinstance(existing_column, exp.Column): - existing_column.replace(exp.alias_(existing_column, new)) - else: - existing_column.set("alias", exp.to_identifier(new)) - return self.copy(expression=expression) - - @operation(Operation.SELECT) - def drop(self, *cols: t.Union[str, Column]) -> DataFrame: - all_columns = self._get_outer_select_columns(self.expression) - drop_cols = self._ensure_and_normalize_cols(cols) - new_columns = [ - col - for col in all_columns - if col.alias_or_name not in [drop_column.alias_or_name for drop_column in drop_cols] - ] - return self.copy().select(*new_columns, append=False) - - @operation(Operation.LIMIT) - def limit(self, num: int) -> DataFrame: - return self.copy(expression=self.expression.limit(num)) - - @operation(Operation.NO_OP) - def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame: - parameter_list = ensure_list(parameters) - parameter_columns = ( - self._ensure_list_of_columns(parameter_list) - if parameters - else Column.ensure_cols([self.sequence_id]) - ) - return self._hint(name, parameter_columns) - - @operation(Operation.NO_OP) - def repartition( - self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName - ) -> DataFrame: - num_partition_cols = self._ensure_list_of_columns(numPartitions) - columns = self._ensure_and_normalize_cols(cols) - args = num_partition_cols + columns - return self._hint("repartition", args) - - @operation(Operation.NO_OP) - def coalesce(self, numPartitions: int) -> DataFrame: - num_partitions = Column.ensure_cols([numPartitions]) - return self._hint("coalesce", num_partitions) - - @operation(Operation.NO_OP) - def cache(self) -> DataFrame: - return self._cache(storage_level="MEMORY_AND_DISK") - - @operation(Operation.NO_OP) - def persist(self, storageLevel: str = 
"MEMORY_AND_DISK_SER") -> DataFrame: - """ - Storage Level Options: https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-aux-cache-cache-table.html - """ - return self._cache(storageLevel) - - -class DataFrameNaFunctions: - def __init__(self, df: DataFrame): - self.df = df - - def drop( - self, - how: str = "any", - thresh: t.Optional[int] = None, - subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, - ) -> DataFrame: - return self.df.dropna(how=how, thresh=thresh, subset=subset) - - def fill( - self, - value: t.Union[int, bool, float, str, t.Dict[str, t.Any]], - subset: t.Optional[t.Union[str, t.Tuple[str, ...], t.List[str]]] = None, - ) -> DataFrame: - return self.df.fillna(value=value, subset=subset) - - def replace( - self, - to_replace: t.Union[bool, int, float, str, t.List, t.Dict], - value: t.Optional[t.Union[bool, int, float, str, t.List]] = None, - subset: t.Optional[t.Union[str, t.List[str]]] = None, - ) -> DataFrame: - return self.df.replace(to_replace=to_replace, value=value, subset=subset) diff --git a/altimate_packages/sqlglot/dataframe/sql/functions.py b/altimate_packages/sqlglot/dataframe/sql/functions.py deleted file mode 100644 index d0ae50cc6..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/functions.py +++ /dev/null @@ -1,1267 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp as expression -from sqlglot.dataframe.sql.column import Column -from sqlglot.helper import ensure_list, flatten as _flatten - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnOrLiteral, ColumnOrName - from sqlglot.dataframe.sql.dataframe import DataFrame - - -def col(column_name: t.Union[ColumnOrName, t.Any]) -> Column: - return Column(column_name) - - -def lit(value: t.Optional[t.Any] = None) -> Column: - if isinstance(value, str): - return Column(expression.Literal.string(str(value))) - return Column(value) - - -def greatest(*cols: ColumnOrName) -> Column: - if len(cols) > 1: - return Column.invoke_expression_over_column( - cols[0], expression.Greatest, expressions=cols[1:] - ) - return Column.invoke_expression_over_column(cols[0], expression.Greatest) - - -def least(*cols: ColumnOrName) -> Column: - if len(cols) > 1: - return Column.invoke_expression_over_column(cols[0], expression.Least, expressions=cols[1:]) - return Column.invoke_expression_over_column(cols[0], expression.Least) - - -def count_distinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: - columns = [Column.ensure_col(x) for x in [col] + list(cols)] - return Column( - expression.Count(this=expression.Distinct(expressions=[x.expression for x in columns])) - ) - - -def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column: - return count_distinct(col, *cols) - - -def when(condition: Column, value: t.Any) -> Column: - true_value = value if isinstance(value, Column) else lit(value) - return Column( - expression.Case( - ifs=[expression.If(this=condition.column_expression, true=true_value.column_expression)] - ) - ) - - -def asc(col: ColumnOrName) -> Column: - return Column.ensure_col(col).asc() - - -def desc(col: ColumnOrName): - return Column.ensure_col(col).desc() - - -def broadcast(df: DataFrame) -> DataFrame: - return df.hint("broadcast") - - -def sqrt(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Sqrt) - - -def abs(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Abs) - - -def max(col: ColumnOrName) -> Column: - return 
Column.invoke_expression_over_column(col, expression.Max) - - -def min(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Min) - - -def max_by(col: ColumnOrName, ord: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAX_BY", ord) - - -def min_by(col: ColumnOrName, ord: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MIN_BY", ord) - - -def count(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Count) - - -def sum(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Sum) - - -def avg(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Avg) - - -def mean(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MEAN") - - -def sumDistinct(col: ColumnOrName) -> Column: - return sum_distinct(col) - - -def sum_distinct(col: ColumnOrName) -> Column: - raise NotImplementedError("Sum distinct is not currently implemented") - - -def product(col: ColumnOrName) -> Column: - raise NotImplementedError("Product is not currently implemented") - - -def acos(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ACOS") - - -def acosh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ACOSH") - - -def asin(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ASIN") - - -def asinh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ASINH") - - -def atan(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ATAN") - - -def atan2(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_anonymous_function(col1, "ATAN2", col2) - - -def atanh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ATANH") - - -def cbrt(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "CBRT") - - -def ceil(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Ceil) - - -def cos(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "COS") - - -def cosh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "COSH") - - -def cot(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "COT") - - -def csc(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "CSC") - - -def exp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Exp) - - -def expm1(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "EXPM1") - - -def floor(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Floor) - - -def log10(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Log10) - - -def log1p(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "LOG1P") - - -def log2(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Log2) - - -def log(arg1: t.Union[ColumnOrName, float], arg2: t.Optional[ColumnOrName] = None) -> Column: - if arg2 is None: - return Column.invoke_expression_over_column(arg1, expression.Ln) - return Column.invoke_expression_over_column(arg1, expression.Log, expression=arg2) - - -def rint(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, 
"RINT") - - -def sec(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SEC") - - -def signum(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SIGNUM") - - -def sin(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SIN") - - -def sinh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SINH") - - -def tan(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "TAN") - - -def tanh(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "TANH") - - -def toDegrees(col: ColumnOrName) -> Column: - return degrees(col) - - -def degrees(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "DEGREES") - - -def toRadians(col: ColumnOrName) -> Column: - return radians(col) - - -def radians(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "RADIANS") - - -def bitwiseNOT(col: ColumnOrName) -> Column: - return bitwise_not(col) - - -def bitwise_not(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.BitwiseNot) - - -def asc_nulls_first(col: ColumnOrName) -> Column: - return Column.ensure_col(col).asc_nulls_first() - - -def asc_nulls_last(col: ColumnOrName) -> Column: - return Column.ensure_col(col).asc_nulls_last() - - -def desc_nulls_first(col: ColumnOrName) -> Column: - return Column.ensure_col(col).desc_nulls_first() - - -def desc_nulls_last(col: ColumnOrName) -> Column: - return Column.ensure_col(col).desc_nulls_last() - - -def stddev(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Stddev) - - -def stddev_samp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.StddevSamp) - - -def stddev_pop(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.StddevPop) - - -def variance(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Variance) - - -def var_samp(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Variance) - - -def var_pop(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.VariancePop) - - -def skewness(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SKEWNESS") - - -def kurtosis(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "KURTOSIS") - - -def collect_list(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.ArrayAgg) - - -def collect_set(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.SetAgg) - - -def hypot(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_anonymous_function(col1, "HYPOT", col2) - - -def pow(col1: t.Union[ColumnOrName, float], col2: t.Union[ColumnOrName, float]) -> Column: - return Column.invoke_expression_over_column(col1, expression.Pow, expression=col2) - - -def row_number() -> Column: - return Column(expression.Anonymous(this="ROW_NUMBER")) - - -def dense_rank() -> Column: - return Column(expression.Anonymous(this="DENSE_RANK")) - - -def rank() -> Column: - return Column(expression.Anonymous(this="RANK")) - - -def cume_dist() -> Column: - return Column(expression.Anonymous(this="CUME_DIST")) - - -def percent_rank() -> Column: - return Column(expression.Anonymous(this="PERCENT_RANK")) - - -def 
approxCountDistinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: - return approx_count_distinct(col, rsd) - - -def approx_count_distinct(col: ColumnOrName, rsd: t.Optional[float] = None) -> Column: - if rsd is None: - return Column.invoke_expression_over_column(col, expression.ApproxDistinct) - return Column.invoke_expression_over_column(col, expression.ApproxDistinct, accuracy=rsd) - - -def coalesce(*cols: ColumnOrName) -> Column: - if len(cols) > 1: - return Column.invoke_expression_over_column( - cols[0], expression.Coalesce, expressions=cols[1:] - ) - return Column.invoke_expression_over_column(cols[0], expression.Coalesce) - - -def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "CORR", col2) - - -def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "COVAR_POP", col2) - - -def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2) - - -def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - return Column.invoke_expression_over_column(col, expression.First, ignore_nulls=ignorenulls) - - -def grouping_id(*cols: ColumnOrName) -> Column: - if not cols: - return Column.invoke_anonymous_function(None, "GROUPING_ID") - if len(cols) == 1: - return Column.invoke_anonymous_function(cols[0], "GROUPING_ID") - return Column.invoke_anonymous_function(cols[0], "GROUPING_ID", *cols[1:]) - - -def input_file_name() -> Column: - return Column.invoke_anonymous_function(None, "INPUT_FILE_NAME") - - -def isnan(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.IsNan) - - -def isnull(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ISNULL") - - -def last(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: - return Column.invoke_expression_over_column(col, expression.Last, ignore_nulls=ignorenulls) - - -def monotonically_increasing_id() -> Column: - return Column.invoke_anonymous_function(None, "MONOTONICALLY_INCREASING_ID") - - -def nanvl(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "NANVL", col2) - - -def percentile_approx( - col: ColumnOrName, - percentage: t.Union[ColumnOrLiteral, t.List[float], t.Tuple[float]], - accuracy: t.Optional[t.Union[ColumnOrLiteral, int]] = None, -) -> Column: - if accuracy: - return Column.invoke_expression_over_column( - col, expression.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy - ) - return Column.invoke_expression_over_column( - col, expression.ApproxQuantile, quantile=lit(percentage) - ) - - -def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: - return Column.invoke_anonymous_function(seed, "RAND") - - -def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: - return Column.invoke_anonymous_function(seed, "RANDN") - - -def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: - if scale is not None: - return Column.invoke_expression_over_column(col, expression.Round, decimals=scale) - return Column.invoke_expression_over_column(col, expression.Round) - - -def bround(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: - if scale is not None: - return Column.invoke_anonymous_function(col, "BROUND", scale) - return Column.invoke_anonymous_function(col, "BROUND") - - -def shiftleft(col: ColumnOrName, numBits: int) -> Column: - return 
Column.invoke_expression_over_column( - col, expression.BitwiseLeftShift, expression=numBits - ) - - -def shiftLeft(col: ColumnOrName, numBits: int) -> Column: - return shiftleft(col, numBits) - - -def shiftright(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_expression_over_column( - col, expression.BitwiseRightShift, expression=numBits - ) - - -def shiftRight(col: ColumnOrName, numBits: int) -> Column: - return shiftright(col, numBits) - - -def shiftrightunsigned(col: ColumnOrName, numBits: int) -> Column: - return Column.invoke_anonymous_function(col, "SHIFTRIGHTUNSIGNED", numBits) - - -def shiftRightUnsigned(col: ColumnOrName, numBits: int) -> Column: - return shiftrightunsigned(col, numBits) - - -def expr(str: str) -> Column: - return Column(str) - - -def struct(col: t.Union[ColumnOrName, t.Iterable[ColumnOrName]], *cols: ColumnOrName) -> Column: - columns = ensure_list(col) + list(cols) - return Column.invoke_expression_over_column(None, expression.Struct, expressions=columns) - - -def conv(col: ColumnOrName, fromBase: int, toBase: int) -> Column: - return Column.invoke_anonymous_function(col, "CONV", fromBase, toBase) - - -def factorial(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "FACTORIAL") - - -def lag( - col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None -) -> Column: - if default is not None: - return Column.invoke_anonymous_function(col, "LAG", offset, default) - if offset != 1: - return Column.invoke_anonymous_function(col, "LAG", offset) - return Column.invoke_anonymous_function(col, "LAG") - - -def lead( - col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None -) -> Column: - if default is not None: - return Column.invoke_anonymous_function(col, "LEAD", offset, default) - if offset != 1: - return Column.invoke_anonymous_function(col, "LEAD", offset) - return Column.invoke_anonymous_function(col, "LEAD") - - -def nth_value( - col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None -) -> Column: - if ignoreNulls is not None: - raise NotImplementedError("There is currently not support for `ignoreNulls` parameter") - if offset != 1: - return Column.invoke_anonymous_function(col, "NTH_VALUE", offset) - return Column.invoke_anonymous_function(col, "NTH_VALUE") - - -def ntile(n: int) -> Column: - return Column.invoke_anonymous_function(None, "NTILE", n) - - -def current_date() -> Column: - return Column.invoke_expression_over_column(None, expression.CurrentDate) - - -def current_timestamp() -> Column: - return Column.invoke_expression_over_column(None, expression.CurrentTimestamp) - - -def date_format(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, expression.TimeToStr, format=lit(format)) - - -def year(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Year) - - -def quarter(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "QUARTER") - - -def month(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Month) - - -def dayofweek(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.DayOfWeek) - - -def dayofmonth(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.DayOfMonth) - - -def dayofyear(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.DayOfYear) - - -def 
hour(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "HOUR") - - -def minute(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MINUTE") - - -def second(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SECOND") - - -def weekofyear(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.WeekOfYear) - - -def make_date(year: ColumnOrName, month: ColumnOrName, day: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(year, "MAKE_DATE", month, day) - - -def date_add(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column( - col, expression.DateAdd, expression=days, unit=expression.Var(this="day") - ) - - -def date_sub(col: ColumnOrName, days: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_expression_over_column( - col, expression.DateSub, expression=days, unit=expression.Var(this="day") - ) - - -def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(end, expression.DateDiff, expression=start) - - -def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column: - return Column.invoke_anonymous_function(start, "ADD_MONTHS", months) - - -def months_between( - date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None -) -> Column: - if roundOff is None: - return Column.invoke_expression_over_column( - date1, expression.MonthsBetween, expression=date2 - ) - - return Column.invoke_expression_over_column( - date1, expression.MonthsBetween, expression=date2, roundoff=roundOff - ) - - -def to_date(col: ColumnOrName, format: t.Optional[str] = None) -> Column: - if format is not None: - return Column.invoke_expression_over_column( - col, expression.TsOrDsToDate, format=lit(format) - ) - return Column.invoke_expression_over_column(col, expression.TsOrDsToDate) - - -def to_timestamp(col: ColumnOrName, format: t.Optional[str] = None) -> Column: - if format is not None: - return Column.invoke_expression_over_column(col, expression.StrToTime, format=lit(format)) - - return Column.ensure_col(col).cast("timestamp") - - -def trunc(col: ColumnOrName, format: str) -> Column: - return Column.invoke_expression_over_column(col, expression.DateTrunc, unit=lit(format)) - - -def date_trunc(format: str, timestamp: ColumnOrName) -> Column: - return Column.invoke_expression_over_column( - timestamp, expression.TimestampTrunc, unit=lit(format) - ) - - -def next_day(col: ColumnOrName, dayOfWeek: str) -> Column: - return Column.invoke_anonymous_function(col, "NEXT_DAY", lit(dayOfWeek)) - - -def last_day(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "LAST_DAY") - - -def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column: - if format is not None: - return Column.invoke_expression_over_column(col, expression.UnixToStr, format=lit(format)) - return Column.invoke_expression_over_column(col, expression.UnixToStr) - - -def unix_timestamp( - timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None -) -> Column: - if format is not None: - return Column.invoke_expression_over_column( - timestamp, expression.StrToUnix, format=lit(format) - ) - return Column.invoke_expression_over_column(timestamp, expression.StrToUnix) - - -def from_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: - tz_column = tz if isinstance(tz, Column) else lit(tz) - return 
Column.invoke_anonymous_function(timestamp, "FROM_UTC_TIMESTAMP", tz_column) - - -def to_utc_timestamp(timestamp: ColumnOrName, tz: ColumnOrName) -> Column: - tz_column = tz if isinstance(tz, Column) else lit(tz) - return Column.invoke_anonymous_function(timestamp, "TO_UTC_TIMESTAMP", tz_column) - - -def timestamp_seconds(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "TIMESTAMP_SECONDS") - - -def window( - timeColumn: ColumnOrName, - windowDuration: str, - slideDuration: t.Optional[str] = None, - startTime: t.Optional[str] = None, -) -> Column: - if slideDuration is not None and startTime is not None: - return Column.invoke_anonymous_function( - timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime) - ) - if slideDuration is not None: - return Column.invoke_anonymous_function( - timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration) - ) - if startTime is not None: - return Column.invoke_anonymous_function( - timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime) - ) - return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration)) - - -def session_window(timeColumn: ColumnOrName, gapDuration: ColumnOrName) -> Column: - gap_duration_column = gapDuration if isinstance(gapDuration, Column) else lit(gapDuration) - return Column.invoke_anonymous_function(timeColumn, "SESSION_WINDOW", gap_duration_column) - - -def crc32(col: ColumnOrName) -> Column: - column = col if isinstance(col, Column) else lit(col) - return Column.invoke_anonymous_function(column, "CRC32") - - -def md5(col: ColumnOrName) -> Column: - column = col if isinstance(col, Column) else lit(col) - return Column.invoke_expression_over_column(column, expression.MD5) - - -def sha1(col: ColumnOrName) -> Column: - column = col if isinstance(col, Column) else lit(col) - return Column.invoke_expression_over_column(column, expression.SHA) - - -def sha2(col: ColumnOrName, numBits: int) -> Column: - column = col if isinstance(col, Column) else lit(col) - return Column.invoke_expression_over_column(column, expression.SHA2, length=lit(numBits)) - - -def hash(*cols: ColumnOrName) -> Column: - args = cols[1:] if len(cols) > 1 else [] - return Column.invoke_anonymous_function(cols[0], "HASH", *args) - - -def xxhash64(*cols: ColumnOrName) -> Column: - args = cols[1:] if len(cols) > 1 else [] - return Column.invoke_anonymous_function(cols[0], "XXHASH64", *args) - - -def assert_true(col: ColumnOrName, errorMsg: t.Optional[ColumnOrName] = None) -> Column: - if errorMsg is not None: - error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg) - return Column.invoke_anonymous_function(col, "ASSERT_TRUE", error_msg_col) - return Column.invoke_anonymous_function(col, "ASSERT_TRUE") - - -def raise_error(errorMsg: ColumnOrName) -> Column: - error_msg_col = errorMsg if isinstance(errorMsg, Column) else lit(errorMsg) - return Column.invoke_anonymous_function(error_msg_col, "RAISE_ERROR") - - -def upper(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Upper) - - -def lower(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Lower) - - -def ascii(col: ColumnOrLiteral) -> Column: - return Column.invoke_anonymous_function(col, "ASCII") - - -def base64(col: ColumnOrLiteral) -> Column: - return Column.invoke_expression_over_column(col, expression.ToBase64) - - -def unbase64(col: ColumnOrLiteral) -> Column: - return Column.invoke_expression_over_column(col, 
expression.FromBase64) - - -def ltrim(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "LTRIM") - - -def rtrim(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "RTRIM") - - -def trim(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Trim) - - -def concat_ws(sep: str, *cols: ColumnOrName) -> Column: - return Column.invoke_expression_over_column( - None, expression.ConcatWs, expressions=[lit(sep)] + list(cols) - ) - - -def decode(col: ColumnOrName, charset: str) -> Column: - return Column.invoke_expression_over_column( - col, expression.Decode, charset=expression.Literal.string(charset) - ) - - -def encode(col: ColumnOrName, charset: str) -> Column: - return Column.invoke_expression_over_column( - col, expression.Encode, charset=expression.Literal.string(charset) - ) - - -def format_number(col: ColumnOrName, d: int) -> Column: - return Column.invoke_anonymous_function(col, "FORMAT_NUMBER", lit(d)) - - -def format_string(format: str, *cols: ColumnOrName) -> Column: - format_col = lit(format) - columns = [Column.ensure_col(x) for x in cols] - return Column.invoke_anonymous_function(format_col, "FORMAT_STRING", *columns) - - -def instr(col: ColumnOrName, substr: str) -> Column: - return Column.invoke_anonymous_function(col, "INSTR", lit(substr)) - - -def overlay( - src: ColumnOrName, - replace: ColumnOrName, - pos: t.Union[ColumnOrName, int], - len: t.Optional[t.Union[ColumnOrName, int]] = None, -) -> Column: - if len is not None: - return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos, len) - return Column.invoke_anonymous_function(src, "OVERLAY", replace, pos) - - -def sentences( - string: ColumnOrName, - language: t.Optional[ColumnOrName] = None, - country: t.Optional[ColumnOrName] = None, -) -> Column: - if language is not None and country is not None: - return Column.invoke_anonymous_function(string, "SENTENCES", language, country) - if language is not None: - return Column.invoke_anonymous_function(string, "SENTENCES", language) - if country is not None: - return Column.invoke_anonymous_function(string, "SENTENCES", lit("en"), country) - return Column.invoke_anonymous_function(string, "SENTENCES") - - -def substring(str: ColumnOrName, pos: int, len: int) -> Column: - return Column.ensure_col(str).substr(pos, len) - - -def substring_index(str: ColumnOrName, delim: str, count: int) -> Column: - return Column.invoke_anonymous_function(str, "SUBSTRING_INDEX", lit(delim), lit(count)) - - -def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(left, expression.Levenshtein, expression=right) - - -def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column: - substr_col = lit(substr) - if pos is not None: - return Column.invoke_expression_over_column( - str, expression.StrPosition, substr=substr_col, position=pos - ) - return Column.invoke_expression_over_column(str, expression.StrPosition, substr=substr_col) - - -def lpad(col: ColumnOrName, len: int, pad: str) -> Column: - return Column.invoke_anonymous_function(col, "LPAD", lit(len), lit(pad)) - - -def rpad(col: ColumnOrName, len: int, pad: str) -> Column: - return Column.invoke_anonymous_function(col, "RPAD", lit(len), lit(pad)) - - -def repeat(col: ColumnOrName, n: int) -> Column: - return Column.invoke_expression_over_column(col, expression.Repeat, times=lit(n)) - - -def split(str: ColumnOrName, pattern: str, limit: t.Optional[int] = None) -> Column: 
- if limit is not None: - return Column.invoke_expression_over_column( - str, expression.RegexpSplit, expression=lit(pattern).expression, limit=limit - ) - return Column.invoke_expression_over_column( - str, expression.RegexpSplit, expression=lit(pattern) - ) - - -def regexp_extract(str: ColumnOrName, pattern: str, idx: t.Optional[int] = None) -> Column: - return Column.invoke_expression_over_column( - str, - expression.RegexpExtract, - expression=lit(pattern), - group=idx, - ) - - -def regexp_replace( - str: ColumnOrName, pattern: str, replacement: str, position: t.Optional[int] = None -) -> Column: - return Column.invoke_expression_over_column( - str, - expression.RegexpReplace, - expression=lit(pattern), - replacement=lit(replacement), - position=position, - ) - - -def initcap(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Initcap) - - -def soundex(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SOUNDEX") - - -def bin(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "BIN") - - -def hex(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Hex) - - -def unhex(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Unhex) - - -def length(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Length) - - -def octet_length(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "OCTET_LENGTH") - - -def bit_length(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "BIT_LENGTH") - - -def translate(srcCol: ColumnOrName, matching: str, replace: str) -> Column: - return Column.invoke_anonymous_function(srcCol, "TRANSLATE", lit(matching), lit(replace)) - - -def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: - columns = _flatten(cols) if not isinstance(cols[0], (str, Column)) else cols - return Column.invoke_expression_over_column(None, expression.Array, expressions=columns) - - -def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: - cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore - return Column.invoke_expression_over_column( - None, - expression.VarMap, - keys=array(*cols[::2]).expression, - values=array(*cols[1::2]).expression, - ) - - -def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, expression.Map, keys=col1, values=col2) - - -def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column: - value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_expression_over_column( - col, expression.ArrayContains, expression=value_col.expression - ) - - -def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2)) - - -def slice( - x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int] -) -> Column: - start_col = start if isinstance(start, Column) else lit(start) - length_col = length if isinstance(length, Column) else lit(length) - return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col) - - -def array_join( - col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None -) -> Column: - if null_replacement is not None: - return Column.invoke_expression_over_column( - col, 
expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement) - ) - return Column.invoke_expression_over_column( - col, expression.ArrayJoin, expression=lit(delimiter) - ) - - -def concat(*cols: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(None, expression.Concat, expressions=cols) - - -def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column: - value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_anonymous_function(col, "ARRAY_POSITION", value_col) - - -def element_at(col: ColumnOrName, value: ColumnOrLiteral) -> Column: - value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_anonymous_function(col, "ELEMENT_AT", value_col) - - -def array_remove(col: ColumnOrName, value: ColumnOrLiteral) -> Column: - value_col = value if isinstance(value, Column) else lit(value) - return Column.invoke_anonymous_function(col, "ARRAY_REMOVE", value_col) - - -def array_distinct(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ARRAY_DISTINCT") - - -def array_intersect(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "ARRAY_INTERSECT", Column.ensure_col(col2)) - - -def array_union(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "ARRAY_UNION", Column.ensure_col(col2)) - - -def array_except(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "ARRAY_EXCEPT", Column.ensure_col(col2)) - - -def explode(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Explode) - - -def posexplode(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.Posexplode) - - -def explode_outer(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "EXPLODE_OUTER") - - -def posexplode_outer(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "POSEXPLODE_OUTER") - - -def get_json_object(col: ColumnOrName, path: str) -> Column: - return Column.invoke_expression_over_column(col, expression.JSONExtract, path=lit(path)) - - -def json_tuple(col: ColumnOrName, *fields: str) -> Column: - return Column.invoke_anonymous_function(col, "JSON_TUPLE", *[lit(field) for field in fields]) - - -def from_json( - col: ColumnOrName, - schema: t.Union[Column, str], - options: t.Optional[t.Dict[str, str]] = None, -) -> Column: - schema = schema if isinstance(schema, Column) else lit(schema) - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "FROM_JSON", schema, options_col) - return Column.invoke_anonymous_function(col, "FROM_JSON", schema) - - -def to_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_expression_over_column(col, expression.JSONFormat, options=options_col) - return Column.invoke_expression_over_column(col, expression.JSONFormat) - - -def schema_of_json(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON", options_col) - return Column.invoke_anonymous_function(col, "SCHEMA_OF_JSON") - - -def schema_of_csv(col: 
ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV", options_col) - return Column.invoke_anonymous_function(col, "SCHEMA_OF_CSV") - - -def to_csv(col: ColumnOrName, options: t.Optional[t.Dict[str, str]] = None) -> Column: - if options is not None: - options_col = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "TO_CSV", options_col) - return Column.invoke_anonymous_function(col, "TO_CSV") - - -def size(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.ArraySize) - - -def array_min(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ARRAY_MIN") - - -def array_max(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "ARRAY_MAX") - - -def sort_array(col: ColumnOrName, asc: t.Optional[bool] = None) -> Column: - if asc is not None: - return Column.invoke_expression_over_column(col, expression.SortArray, asc=asc) - return Column.invoke_expression_over_column(col, expression.SortArray) - - -def array_sort( - col: ColumnOrName, - comparator: t.Optional[t.Union[t.Callable[[Column, Column], Column]]] = None, -) -> Column: - if comparator is not None: - f_expression = _get_lambda_from_func(comparator) - return Column.invoke_expression_over_column( - col, expression.ArraySort, expression=f_expression - ) - return Column.invoke_expression_over_column(col, expression.ArraySort) - - -def shuffle(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "SHUFFLE") - - -def reverse(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "REVERSE") - - -def flatten(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "FLATTEN") - - -def map_keys(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAP_KEYS") - - -def map_values(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAP_VALUES") - - -def map_entries(col: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col, "MAP_ENTRIES") - - -def map_from_entries(col: ColumnOrName) -> Column: - return Column.invoke_expression_over_column(col, expression.MapFromEntries) - - -def array_repeat(col: ColumnOrName, count: t.Union[ColumnOrName, int]) -> Column: - count_col = count if isinstance(count, Column) else lit(count) - return Column.invoke_anonymous_function(col, "ARRAY_REPEAT", count_col) - - -def array_zip(*cols: ColumnOrName) -> Column: - if len(cols) == 1: - return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP") - return Column.invoke_anonymous_function(cols[0], "ARRAY_ZIP", *cols[1:]) - - -def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column: - columns = list(flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore - if len(columns) == 1: - return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT") - return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:]) - - -def sequence( - start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None -) -> Column: - if step is not None: - return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step) - return Column.invoke_anonymous_function(start, "SEQUENCE", stop) - - -def from_csv( - col: ColumnOrName, - schema: t.Union[Column, str], - options: 
t.Optional[t.Dict[str, str]] = None, -) -> Column: - schema = schema if isinstance(schema, Column) else lit(schema) - if options is not None: - option_cols = create_map([lit(x) for x in _flatten(options.items())]) - return Column.invoke_anonymous_function(col, "FROM_CSV", schema, option_cols) - return Column.invoke_anonymous_function(col, "FROM_CSV", schema) - - -def aggregate( - col: ColumnOrName, - initialValue: ColumnOrName, - merge: t.Callable[[Column, Column], Column], - finish: t.Optional[t.Callable[[Column], Column]] = None, -) -> Column: - merge_exp = _get_lambda_from_func(merge) - if finish is not None: - finish_exp = _get_lambda_from_func(finish) - return Column.invoke_expression_over_column( - col, - expression.Reduce, - initial=initialValue, - merge=Column(merge_exp), - finish=Column(finish_exp), - ) - return Column.invoke_expression_over_column( - col, expression.Reduce, initial=initialValue, merge=Column(merge_exp) - ) - - -def transform( - col: ColumnOrName, - f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], -) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_expression_over_column( - col, expression.Transform, expression=Column(f_expression) - ) - - -def exists(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "EXISTS", Column(f_expression)) - - -def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression)) - - -def filter( - col: ColumnOrName, - f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]], -) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_expression_over_column( - col, expression.ArrayFilter, expression=f_expression - ) - - -def zip_with( - left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column] -) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression)) - - -def transform_keys(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "TRANSFORM_KEYS", Column(f_expression)) - - -def transform_values(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "TRANSFORM_VALUES", Column(f_expression)) - - -def map_filter(col: ColumnOrName, f: t.Union[t.Callable[[Column, Column], Column]]) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col, "MAP_FILTER", Column(f_expression)) - - -def map_zip_with( - col1: ColumnOrName, - col2: ColumnOrName, - f: t.Union[t.Callable[[Column, Column, Column], Column]], -) -> Column: - f_expression = _get_lambda_from_func(f) - return Column.invoke_anonymous_function(col1, "MAP_ZIP_WITH", col2, Column(f_expression)) - - -def _lambda_quoted(value: str) -> t.Optional[bool]: - return False if value == "_" else None - - -def _get_lambda_from_func(lambda_expression: t.Callable): - variables = [ - expression.to_identifier(x, quoted=_lambda_quoted(x)) - for x in lambda_expression.__code__.co_varnames - ] - return expression.Lambda( - this=lambda_expression(*[Column(x) for x in variables]).expression, - expressions=variables, - ) 
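
The `functions.py` file deleted above was the expression-building layer of the removed `sqlglot.dataframe` package: each PySpark-style helper wraps a column or literal in a sqlglot expression node (e.g. `coalesce()` -> `expression.Coalesce`, `upper()` -> `expression.Upper`), and the session/DataFrame classes deleted further down render those expressions back to SQL. The sketch below is a minimal, hedged illustration of how that pre-removal API was typically driven; it only uses names visible in the deleted files (`SparkSession`, `createDataFrame`, `F.coalesce`, `F.upper`, `Column.alias`), while `DataFrame.select` and `DataFrame.sql` returning a list of SQL strings are assumptions inferred from the deleted `readwriter.py` (`DataFrameWriter.sql` delegates to `self._df.sql(**kwargs)` and is typed `t.List[str]`), not confirmed by this diff.

```python
# Hedged usage sketch of the removed sqlglot.dataframe API (pre-deletion behavior).
# Assumptions (not shown verbatim in this diff): DataFrame.select exists and
# DataFrame.sql() returns a list of SQL strings for the session's default dialect.
from sqlglot.dataframe.sql import functions as F
from sqlglot.dataframe.sql.session import SparkSession

spark = SparkSession()  # singleton; defaults to the "spark" dialect

# A list-of-strings schema is accepted per the deleted createDataFrame/util.py code;
# each row value is wrapped with F.lit() internally.
df = spark.createDataFrame([(1, 2), (3, 4)], schema=["a", "b"])

# Column expressions come from the helpers in the deleted functions.py.
df = df.select(F.coalesce("a", "b").alias("c"), F.upper("a"))

# Assumed to yield one SQL string per statement, e.g. a SELECT over a VALUES CTE.
for statement in df.sql():
    print(statement)
```
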
diff --git a/altimate_packages/sqlglot/dataframe/sql/group.py b/altimate_packages/sqlglot/dataframe/sql/group.py deleted file mode 100644 index ba27c170d..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/group.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot.dataframe.sql import functions as F -from sqlglot.dataframe.sql.column import Column -from sqlglot.dataframe.sql.operations import Operation, operation - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.dataframe import DataFrame - - -class GroupedData: - def __init__(self, df: DataFrame, group_by_cols: t.List[Column], last_op: Operation): - self._df = df.copy() - self.spark = df.spark - self.last_op = last_op - self.group_by_cols = group_by_cols - - def _get_function_applied_columns( - self, func_name: str, cols: t.Tuple[str, ...] - ) -> t.List[Column]: - func_name = func_name.lower() - return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols] - - @operation(Operation.SELECT) - def agg(self, *exprs: t.Union[Column, t.Dict[str, str]]) -> DataFrame: - columns = ( - [Column(f"{agg_func}({column_name})") for column_name, agg_func in exprs[0].items()] - if isinstance(exprs[0], dict) - else exprs - ) - cols = self._df._ensure_and_normalize_cols(columns) - - expression = self._df.expression.group_by( - *[x.expression for x in self.group_by_cols] - ).select(*[x.expression for x in self.group_by_cols + cols], append=False) - return self._df.copy(expression=expression) - - def count(self) -> DataFrame: - return self.agg(F.count("*").alias("count")) - - def mean(self, *cols: str) -> DataFrame: - return self.avg(*cols) - - def avg(self, *cols: str) -> DataFrame: - return self.agg(*self._get_function_applied_columns("avg", cols)) - - def max(self, *cols: str) -> DataFrame: - return self.agg(*self._get_function_applied_columns("max", cols)) - - def min(self, *cols: str) -> DataFrame: - return self.agg(*self._get_function_applied_columns("min", cols)) - - def sum(self, *cols: str) -> DataFrame: - return self.agg(*self._get_function_applied_columns("sum", cols)) - - def pivot(self, *cols: str) -> DataFrame: - raise NotImplementedError("Sum distinct is not currently implemented") diff --git a/altimate_packages/sqlglot/dataframe/sql/normalize.py b/altimate_packages/sqlglot/dataframe/sql/normalize.py deleted file mode 100644 index f68bacb2f..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/normalize.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import expressions as exp -from sqlglot.dataframe.sql.column import Column -from sqlglot.dataframe.sql.util import get_tables_from_expression_with_join -from sqlglot.helper import ensure_list - -NORMALIZE_INPUT = t.TypeVar("NORMALIZE_INPUT", bound=t.Union[str, exp.Expression, Column]) - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.session import SparkSession - - -def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[NORMALIZE_INPUT]): - expr = ensure_list(expr) - expressions = _ensure_expressions(expr) - for expression in expressions: - identifiers = expression.find_all(exp.Identifier) - for identifier in identifiers: - identifier.transform(spark.dialect.normalize_identifier) - replace_alias_name_with_cte_name(spark, expression_context, identifier) - replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier) - - -def replace_alias_name_with_cte_name( - spark: SparkSession, expression_context: 
exp.Select, id: exp.Identifier -): - if id.alias_or_name in spark.name_to_sequence_id_mapping: - for cte in reversed(expression_context.ctes): - if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]: - _set_alias_name(id, cte.alias_or_name) - break - - -def replace_branch_and_sequence_ids_with_cte_name( - spark: SparkSession, expression_context: exp.Select, id: exp.Identifier -): - if id.alias_or_name in spark.known_ids: - # Check if we have a join and if both the tables in that join share a common branch id - # If so we need to have this reference the left table by default unless the id is a sequence - # id then it keeps that reference. This handles the weird edge case in spark that shouldn't - # be common in practice - if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids: - join_table_aliases = [ - x.alias_or_name for x in get_tables_from_expression_with_join(expression_context) - ] - ctes_in_join = [ - cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases - ] - if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]: - assert len(ctes_in_join) == 2 - _set_alias_name(id, ctes_in_join[0].alias_or_name) - return - - for cte in reversed(expression_context.ctes): - if id.alias_or_name in (cte.args["branch_id"], cte.args["sequence_id"]): - _set_alias_name(id, cte.alias_or_name) - return - - -def _set_alias_name(id: exp.Identifier, name: str): - id.set("this", name) - - -def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]: - results = [] - for value in values: - if isinstance(value, str): - results.append(Column.ensure_col(value).expression) - elif isinstance(value, Column): - results.append(value.expression) - elif isinstance(value, exp.Expression): - results.append(value) - else: - raise ValueError(f"Got an invalid type to normalize: {type(value)}") - return results diff --git a/altimate_packages/sqlglot/dataframe/sql/operations.py b/altimate_packages/sqlglot/dataframe/sql/operations.py deleted file mode 100644 index e4c106b3b..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/operations.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import functools -import typing as t -from enum import IntEnum - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.dataframe import DataFrame - from sqlglot.dataframe.sql.group import GroupedData - - -class Operation(IntEnum): - INIT = -1 - NO_OP = 0 - FROM = 1 - WHERE = 2 - GROUP_BY = 3 - HAVING = 4 - SELECT = 5 - ORDER_BY = 6 - LIMIT = 7 - - -def operation(op: Operation): - """ - Decorator used around DataFrame methods to indicate what type of operation is being performed from the - ordered Operation enums. This is used to determine which operations should be performed on a CTE vs. - included with the previous operation. - - Ex: After a user does a join we want to allow them to select which columns for the different - tables that they want to carry through to the following operation. If we put that join in - a CTE preemptively then the user would not have a chance to select which column they want - in cases where there is overlap in names. 
- """ - - def decorator(func: t.Callable): - @functools.wraps(func) - def wrapper(self: DataFrame, *args, **kwargs): - if self.last_op == Operation.INIT: - self = self._convert_leaf_to_cte() - self.last_op = Operation.NO_OP - last_op = self.last_op - new_op = op if op != Operation.NO_OP else last_op - if new_op < last_op or (last_op == new_op == Operation.SELECT): - self = self._convert_leaf_to_cte() - df: t.Union[DataFrame, GroupedData] = func(self, *args, **kwargs) - df.last_op = new_op # type: ignore - return df - - wrapper.__wrapped__ = func # type: ignore - return wrapper - - return decorator diff --git a/altimate_packages/sqlglot/dataframe/sql/readwriter.py b/altimate_packages/sqlglot/dataframe/sql/readwriter.py deleted file mode 100644 index 481cb9f3a..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/readwriter.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -import typing as t - -import sqlglot as sqlglot -from sqlglot import expressions as exp -from sqlglot.helper import object_to_dict - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql.dataframe import DataFrame - from sqlglot.dataframe.sql.session import SparkSession - - -class DataFrameReader: - def __init__(self, spark: SparkSession): - self.spark = spark - - def table(self, tableName: str) -> DataFrame: - from sqlglot.dataframe.sql.dataframe import DataFrame - from sqlglot.dataframe.sql.session import SparkSession - - sqlglot.schema.add_table(tableName, dialect=SparkSession().dialect) - - return DataFrame( - self.spark, - exp.Select() - .from_( - exp.to_table(tableName, dialect=SparkSession().dialect).transform( - SparkSession().dialect.normalize_identifier - ) - ) - .select( - *( - column - for column in sqlglot.schema.column_names( - tableName, dialect=SparkSession().dialect - ) - ) - ), - ) - - -class DataFrameWriter: - def __init__( - self, - df: DataFrame, - spark: t.Optional[SparkSession] = None, - mode: t.Optional[str] = None, - by_name: bool = False, - ): - self._df = df - self._spark = spark or df.spark - self._mode = mode - self._by_name = by_name - - def copy(self, **kwargs) -> DataFrameWriter: - return DataFrameWriter( - **{ - k[1:] if k.startswith("_") else k: v - for k, v in object_to_dict(self, **kwargs).items() - } - ) - - def sql(self, **kwargs) -> t.List[str]: - return self._df.sql(**kwargs) - - def mode(self, saveMode: t.Optional[str]) -> DataFrameWriter: - return self.copy(_mode=saveMode) - - @property - def byName(self): - return self.copy(by_name=True) - - def insertInto(self, tableName: str, overwrite: t.Optional[bool] = None) -> DataFrameWriter: - from sqlglot.dataframe.sql.session import SparkSession - - output_expression_container = exp.Insert( - **{ - "this": exp.to_table(tableName), - "overwrite": overwrite, - } - ) - df = self._df.copy(output_expression_container=output_expression_container) - if self._by_name: - columns = sqlglot.schema.column_names( - tableName, only_visible=True, dialect=SparkSession().dialect - ) - df = df._convert_leaf_to_cte().select(*columns) - - return self.copy(_df=df) - - def saveAsTable(self, name: str, format: t.Optional[str] = None, mode: t.Optional[str] = None): - if format is not None: - raise NotImplementedError("Providing Format in the save as table is not supported") - exists, replace, mode = None, None, mode or str(self._mode) - if mode == "append": - return self.insertInto(name) - if mode == "ignore": - exists = True - if mode == "overwrite": - replace = True - output_expression_container = exp.Create( - this=exp.to_table(name), 
- kind="TABLE", - exists=exists, - replace=replace, - ) - return self.copy(_df=self._df.copy(output_expression_container=output_expression_container)) diff --git a/altimate_packages/sqlglot/dataframe/sql/session.py b/altimate_packages/sqlglot/dataframe/sql/session.py deleted file mode 100644 index 0d1607976..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/session.py +++ /dev/null @@ -1,190 +0,0 @@ -from __future__ import annotations - -import typing as t -import uuid -from collections import defaultdict - -import sqlglot as sqlglot -from sqlglot import Dialect, expressions as exp -from sqlglot.dataframe.sql import functions as F -from sqlglot.dataframe.sql.dataframe import DataFrame -from sqlglot.dataframe.sql.readwriter import DataFrameReader -from sqlglot.dataframe.sql.types import StructType -from sqlglot.dataframe.sql.util import get_column_mapping_from_schema_input -from sqlglot.helper import classproperty - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnLiterals, SchemaInput - - -class SparkSession: - DEFAULT_DIALECT = "spark" - _instance = None - - def __init__(self): - if not hasattr(self, "known_ids"): - self.known_ids = set() - self.known_branch_ids = set() - self.known_sequence_ids = set() - self.name_to_sequence_id_mapping = defaultdict(list) - self.incrementing_id = 1 - self.dialect = Dialect.get_or_raise(self.DEFAULT_DIALECT)() - - def __new__(cls, *args, **kwargs) -> SparkSession: - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - - @property - def read(self) -> DataFrameReader: - return DataFrameReader(self) - - def table(self, tableName: str) -> DataFrame: - return self.read.table(tableName) - - def createDataFrame( - self, - data: t.Sequence[t.Union[t.Dict[str, ColumnLiterals], t.List[ColumnLiterals], t.Tuple]], - schema: t.Optional[SchemaInput] = None, - samplingRatio: t.Optional[float] = None, - verifySchema: bool = False, - ) -> DataFrame: - from sqlglot.dataframe.sql.dataframe import DataFrame - - if samplingRatio is not None or verifySchema: - raise NotImplementedError("Sampling Ratio and Verify Schema are not supported") - if schema is not None and ( - not isinstance(schema, (StructType, str, list)) - or (isinstance(schema, list) and not isinstance(schema[0], str)) - ): - raise NotImplementedError("Only schema of either list or string of list supported") - if not data: - raise ValueError("Must provide data to create into a DataFrame") - - column_mapping: t.Dict[str, t.Optional[str]] - if schema is not None: - column_mapping = get_column_mapping_from_schema_input(schema) - elif isinstance(data[0], dict): - column_mapping = {col_name.strip(): None for col_name in data[0]} - else: - column_mapping = {f"_{i}": None for i in range(1, len(data[0]) + 1)} - - data_expressions = [ - exp.Tuple( - expressions=list( - map( - lambda x: F.lit(x).expression, - row if not isinstance(row, dict) else row.values(), - ) - ) - ) - for row in data - ] - - sel_columns = [ - F.col(name).cast(data_type).alias(name).expression - if data_type is not None - else F.col(name).expression - for name, data_type in column_mapping.items() - ] - - select_kwargs = { - "expressions": sel_columns, - "from": exp.From( - this=exp.Values( - expressions=data_expressions, - alias=exp.TableAlias( - this=exp.to_identifier(self._auto_incrementing_name), - columns=[exp.to_identifier(col_name) for col_name in column_mapping], - ), - ), - ), - } - - sel_expression = exp.Select(**select_kwargs) - return DataFrame(self, sel_expression) - - def 
sql(self, sqlQuery: str) -> DataFrame: - expression = sqlglot.parse_one(sqlQuery, read=self.dialect) - if isinstance(expression, exp.Select): - df = DataFrame(self, expression) - df = df._convert_leaf_to_cte() - elif isinstance(expression, (exp.Create, exp.Insert)): - select_expression = expression.expression.copy() - if isinstance(expression, exp.Insert): - select_expression.set("with", expression.args.get("with")) - expression.set("with", None) - del expression.args["expression"] - df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore - df = df._convert_leaf_to_cte() - else: - raise ValueError( - "Unknown expression type provided in the SQL. Please create an issue with the SQL." - ) - return df - - @property - def _auto_incrementing_name(self) -> str: - name = f"a{self.incrementing_id}" - self.incrementing_id += 1 - return name - - @property - def _random_branch_id(self) -> str: - id = self._random_id - self.known_branch_ids.add(id) - return id - - @property - def _random_sequence_id(self): - id = self._random_id - self.known_sequence_ids.add(id) - return id - - @property - def _random_id(self) -> str: - id = "r" + uuid.uuid4().hex - self.known_ids.add(id) - return id - - @property - def _join_hint_names(self) -> t.Set[str]: - return {"BROADCAST", "MERGE", "SHUFFLE_HASH", "SHUFFLE_REPLICATE_NL"} - - def _add_alias_to_mapping(self, name: str, sequence_id: str): - self.name_to_sequence_id_mapping[name].append(sequence_id) - - class Builder: - SQLFRAME_DIALECT_KEY = "sqlframe.dialect" - - def __init__(self): - self.dialect = "spark" - - def __getattr__(self, item) -> SparkSession.Builder: - return self - - def __call__(self, *args, **kwargs): - return self - - def config( - self, - key: t.Optional[str] = None, - value: t.Optional[t.Any] = None, - *, - map: t.Optional[t.Dict[str, t.Any]] = None, - **kwargs: t.Any, - ) -> SparkSession.Builder: - if key == self.SQLFRAME_DIALECT_KEY: - self.dialect = value - elif map and self.SQLFRAME_DIALECT_KEY in map: - self.dialect = map[self.SQLFRAME_DIALECT_KEY] - return self - - def getOrCreate(self) -> SparkSession: - spark = SparkSession() - spark.dialect = Dialect.get_or_raise(self.dialect)() - return spark - - @classproperty - def builder(cls) -> Builder: - return cls.Builder() diff --git a/altimate_packages/sqlglot/dataframe/sql/transforms.py b/altimate_packages/sqlglot/dataframe/sql/transforms.py deleted file mode 100644 index b3dcc121d..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/transforms.py +++ /dev/null @@ -1,9 +0,0 @@ -import typing as t - -from sqlglot import expressions as exp - - -def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]): - if isinstance(node, exp.Identifier) and node in replacement_mapping: - node = node.replace(replacement_mapping[node].copy()) - return node diff --git a/altimate_packages/sqlglot/dataframe/sql/types.py b/altimate_packages/sqlglot/dataframe/sql/types.py deleted file mode 100644 index a63e50504..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/types.py +++ /dev/null @@ -1,212 +0,0 @@ -import typing as t - - -class DataType: - def __repr__(self) -> str: - return self.__class__.__name__ + "()" - - def __hash__(self) -> int: - return hash(str(self)) - - def __eq__(self, other: t.Any) -> bool: - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - def __ne__(self, other: t.Any) -> bool: - return not self.__eq__(other) - - def __str__(self) -> str: - return self.typeName() - - @classmethod - def 
typeName(cls) -> str: - return cls.__name__[:-4].lower() - - def simpleString(self) -> str: - return str(self) - - def jsonValue(self) -> t.Union[str, t.Dict[str, t.Any]]: - return str(self) - - -class DataTypeWithLength(DataType): - def __init__(self, length: int): - self.length = length - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.length})" - - def __str__(self) -> str: - return f"{self.typeName()}({self.length})" - - -class StringType(DataType): - pass - - -class CharType(DataTypeWithLength): - pass - - -class VarcharType(DataTypeWithLength): - pass - - -class BinaryType(DataType): - pass - - -class BooleanType(DataType): - pass - - -class DateType(DataType): - pass - - -class TimestampType(DataType): - pass - - -class TimestampNTZType(DataType): - @classmethod - def typeName(cls) -> str: - return "timestamp_ntz" - - -class DecimalType(DataType): - def __init__(self, precision: int = 10, scale: int = 0): - self.precision = precision - self.scale = scale - - def simpleString(self) -> str: - return f"decimal({self.precision}, {self.scale})" - - def jsonValue(self) -> str: - return f"decimal({self.precision}, {self.scale})" - - def __repr__(self) -> str: - return f"DecimalType({self.precision}, {self.scale})" - - -class DoubleType(DataType): - pass - - -class FloatType(DataType): - pass - - -class ByteType(DataType): - def __str__(self) -> str: - return "tinyint" - - -class IntegerType(DataType): - def __str__(self) -> str: - return "int" - - -class LongType(DataType): - def __str__(self) -> str: - return "bigint" - - -class ShortType(DataType): - def __str__(self) -> str: - return "smallint" - - -class ArrayType(DataType): - def __init__(self, elementType: DataType, containsNull: bool = True): - self.elementType = elementType - self.containsNull = containsNull - - def __repr__(self) -> str: - return f"ArrayType({self.elementType, str(self.containsNull)}" - - def simpleString(self) -> str: - return f"array<{self.elementType.simpleString()}>" - - def jsonValue(self) -> t.Dict[str, t.Any]: - return { - "type": self.typeName(), - "elementType": self.elementType.jsonValue(), - "containsNull": self.containsNull, - } - - -class MapType(DataType): - def __init__(self, keyType: DataType, valueType: DataType, valueContainsNull: bool = True): - self.keyType = keyType - self.valueType = valueType - self.valueContainsNull = valueContainsNull - - def __repr__(self) -> str: - return f"MapType({self.keyType}, {self.valueType}, {str(self.valueContainsNull)})" - - def simpleString(self) -> str: - return f"map<{self.keyType.simpleString()}, {self.valueType.simpleString()}>" - - def jsonValue(self) -> t.Dict[str, t.Any]: - return { - "type": self.typeName(), - "keyType": self.keyType.jsonValue(), - "valueType": self.valueType.jsonValue(), - "valueContainsNull": self.valueContainsNull, - } - - -class StructField(DataType): - def __init__( - self, - name: str, - dataType: DataType, - nullable: bool = True, - metadata: t.Optional[t.Dict[str, t.Any]] = None, - ): - self.name = name - self.dataType = dataType - self.nullable = nullable - self.metadata = metadata or {} - - def __repr__(self) -> str: - return f"StructField('{self.name}', {self.dataType}, {str(self.nullable)})" - - def simpleString(self) -> str: - return f"{self.name}:{self.dataType.simpleString()}" - - def jsonValue(self) -> t.Dict[str, t.Any]: - return { - "name": self.name, - "type": self.dataType.jsonValue(), - "nullable": self.nullable, - "metadata": self.metadata, - } - - -class StructType(DataType): - def 
__init__(self, fields: t.Optional[t.List[StructField]] = None): - if not fields: - self.fields = [] - self.names = [] - else: - self.fields = fields - self.names = [f.name for f in fields] - - def __iter__(self) -> t.Iterator[StructField]: - return iter(self.fields) - - def __len__(self) -> int: - return len(self.fields) - - def __repr__(self) -> str: - return f"StructType({', '.join(str(field) for field in self)})" - - def simpleString(self) -> str: - return f"struct<{', '.join(x.simpleString() for x in self)}>" - - def jsonValue(self) -> t.Dict[str, t.Any]: - return {"type": self.typeName(), "fields": [x.jsonValue() for x in self]} - - def fieldNames(self) -> t.List[str]: - return list(self.names) diff --git a/altimate_packages/sqlglot/dataframe/sql/util.py b/altimate_packages/sqlglot/dataframe/sql/util.py deleted file mode 100644 index 4b9fbb187..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/util.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import expressions as exp -from sqlglot.dataframe.sql import types - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import SchemaInput - - -def get_column_mapping_from_schema_input(schema: SchemaInput) -> t.Dict[str, t.Optional[str]]: - if isinstance(schema, dict): - return schema - elif isinstance(schema, str): - col_name_type_strs = [x.strip() for x in schema.split(",")] - return { - name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() - for name_type_str in col_name_type_strs - } - elif isinstance(schema, types.StructType): - return {struct_field.name: struct_field.dataType.simpleString() for struct_field in schema} - return {x.strip(): None for x in schema} # type: ignore - - -def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.Table]: - if not expression.args.get("joins"): - return [] - - left_table = expression.args["from"].this - other_tables = [join.this for join in expression.args["joins"]] - return [left_table] + other_tables diff --git a/altimate_packages/sqlglot/dataframe/sql/window.py b/altimate_packages/sqlglot/dataframe/sql/window.py deleted file mode 100644 index c1d913fe0..000000000 --- a/altimate_packages/sqlglot/dataframe/sql/window.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import annotations - -import sys -import typing as t - -from sqlglot import expressions as exp -from sqlglot.dataframe.sql import functions as F -from sqlglot.helper import flatten - -if t.TYPE_CHECKING: - from sqlglot.dataframe.sql._typing import ColumnOrName - - -class Window: - _JAVA_MIN_LONG = -(1 << 63) # -9223372036854775808 - _JAVA_MAX_LONG = (1 << 63) - 1 # 9223372036854775807 - _PRECEDING_THRESHOLD = max(-sys.maxsize, _JAVA_MIN_LONG) - _FOLLOWING_THRESHOLD = min(sys.maxsize, _JAVA_MAX_LONG) - - unboundedPreceding: int = _JAVA_MIN_LONG - - unboundedFollowing: int = _JAVA_MAX_LONG - - currentRow: int = 0 - - @classmethod - def partitionBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: - return WindowSpec().partitionBy(*cols) - - @classmethod - def orderBy(cls, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: - return WindowSpec().orderBy(*cols) - - @classmethod - def rowsBetween(cls, start: int, end: int) -> WindowSpec: - return WindowSpec().rowsBetween(start, end) - - @classmethod - def rangeBetween(cls, start: int, end: int) -> WindowSpec: - return WindowSpec().rangeBetween(start, end) - - -class WindowSpec: - def __init__(self, expression: exp.Expression = exp.Window()): 
- self.expression = expression - - def copy(self): - return WindowSpec(self.expression.copy()) - - def sql(self, **kwargs) -> str: - from sqlglot.dataframe.sql.session import SparkSession - - return self.expression.sql(dialect=SparkSession().dialect, **kwargs) - - def partitionBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: - from sqlglot.dataframe.sql.column import Column - - cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore - expressions = [Column.ensure_col(x).expression for x in cols] - window_spec = self.copy() - partition_by_expressions = window_spec.expression.args.get("partition_by", []) - partition_by_expressions.extend(expressions) - window_spec.expression.set("partition_by", partition_by_expressions) - return window_spec - - def orderBy(self, *cols: t.Union[ColumnOrName, t.List[ColumnOrName]]) -> WindowSpec: - from sqlglot.dataframe.sql.column import Column - - cols = flatten(cols) if isinstance(cols[0], (list, set, tuple)) else cols # type: ignore - expressions = [Column.ensure_col(x).expression for x in cols] - window_spec = self.copy() - if window_spec.expression.args.get("order") is None: - window_spec.expression.set("order", exp.Order(expressions=[])) - order_by = window_spec.expression.args["order"].expressions - order_by.extend(expressions) - window_spec.expression.args["order"].set("expressions", order_by) - return window_spec - - def _calc_start_end( - self, start: int, end: int - ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]: - kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = { - "start_side": None, - "end_side": None, - } - if start == Window.currentRow: - kwargs["start"] = "CURRENT ROW" - else: - kwargs = { - **kwargs, - **{ - "start_side": "PRECEDING", - "start": "UNBOUNDED" - if start <= Window.unboundedPreceding - else F.lit(start).expression, - }, - } - if end == Window.currentRow: - kwargs["end"] = "CURRENT ROW" - else: - kwargs = { - **kwargs, - **{ - "end_side": "FOLLOWING", - "end": "UNBOUNDED" - if end >= Window.unboundedFollowing - else F.lit(end).expression, - }, - } - return kwargs - - def rowsBetween(self, start: int, end: int) -> WindowSpec: - window_spec = self.copy() - spec = self._calc_start_end(start, end) - spec["kind"] = "ROWS" - window_spec.expression.set( - "spec", - exp.WindowSpec( - **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} - ), - ) - return window_spec - - def rangeBetween(self, start: int, end: int) -> WindowSpec: - window_spec = self.copy() - spec = self._calc_start_end(start, end) - spec["kind"] = "RANGE" - window_spec.expression.set( - "spec", - exp.WindowSpec( - **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec} - ), - ) - return window_spec diff --git a/altimate_packages/sqlglot/dialects/__init__.py b/altimate_packages/sqlglot/dialects/__init__.py deleted file mode 100644 index 711496a76..000000000 --- a/altimate_packages/sqlglot/dialects/__init__.py +++ /dev/null @@ -1,118 +0,0 @@ -# ruff: noqa: F401 -""" -## Dialects - -While there is a SQL standard, most SQL engines support a variation of that standard. This makes it difficult -to write portable SQL code. SQLGlot bridges all the different variations, called "dialects", with an extensible -SQL transpilation framework. - -The base `sqlglot.dialects.dialect.Dialect` class implements a generic dialect that aims to be as universal as possible. 
- -Each SQL variation has its own `Dialect` subclass, extending the corresponding `Tokenizer`, `Parser` and `Generator` -classes as needed. - -### Implementing a custom Dialect - -Creating a new SQL dialect may seem complicated at first, but it is actually quite simple in SQLGlot: - -```python -from sqlglot import exp -from sqlglot.dialects.dialect import Dialect -from sqlglot.generator import Generator -from sqlglot.tokens import Tokenizer, TokenType - - -class Custom(Dialect): - class Tokenizer(Tokenizer): - QUOTES = ["'", '"'] # Strings can be delimited by either single or double quotes - IDENTIFIERS = ["`"] # Identifiers can be delimited by backticks - - # Associates certain meaningful words with tokens that capture their intent - KEYWORDS = { - **Tokenizer.KEYWORDS, - "INT64": TokenType.BIGINT, - "FLOAT64": TokenType.DOUBLE, - } - - class Generator(Generator): - # Specifies how AST nodes, i.e. subclasses of exp.Expression, should be converted into SQL - TRANSFORMS = { - exp.Array: lambda self, e: f"[{self.expressions(e)}]", - } - - # Specifies how AST nodes representing data types should be converted into SQL - TYPE_MAPPING = { - exp.DataType.Type.TINYINT: "INT64", - exp.DataType.Type.SMALLINT: "INT64", - exp.DataType.Type.INT: "INT64", - exp.DataType.Type.BIGINT: "INT64", - exp.DataType.Type.DECIMAL: "NUMERIC", - exp.DataType.Type.FLOAT: "FLOAT64", - exp.DataType.Type.DOUBLE: "FLOAT64", - exp.DataType.Type.BOOLEAN: "BOOL", - exp.DataType.Type.TEXT: "STRING", - } -``` - -The above example demonstrates how certain parts of the base `Dialect` class can be overridden to match a different -specification. Even though it is a fairly realistic starting point, we strongly encourage the reader to study existing -dialect implementations in order to understand how their various components can be modified, depending on the use-case. 
- ----- -""" - -import importlib -import threading - -DIALECTS = [ - "Athena", - "BigQuery", - "ClickHouse", - "Databricks", - "Doris", - "Drill", - "Druid", - "DuckDB", - "Dune", - "Hive", - "Materialize", - "MySQL", - "Oracle", - "Postgres", - "Presto", - "PRQL", - "Redshift", - "RisingWave", - "Snowflake", - "Spark", - "Spark2", - "SQLite", - "StarRocks", - "Tableau", - "Teradata", - "Trino", - "TSQL", -] - -MODULE_BY_DIALECT = {name: name.lower() for name in DIALECTS} -DIALECT_MODULE_NAMES = MODULE_BY_DIALECT.values() - -MODULE_BY_ATTRIBUTE = { - **MODULE_BY_DIALECT, - "Dialect": "dialect", - "Dialects": "dialect", -} - -__all__ = list(MODULE_BY_ATTRIBUTE) - -_import_lock = threading.Lock() - - -def __getattr__(name): - module_name = MODULE_BY_ATTRIBUTE.get(name) - if module_name: - with _import_lock: - module = importlib.import_module(f"sqlglot.dialects.{module_name}") - return getattr(module, name) - - raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/altimate_packages/sqlglot/dialects/athena.py b/altimate_packages/sqlglot/dialects/athena.py deleted file mode 100644 index e2aaa967c..000000000 --- a/altimate_packages/sqlglot/dialects/athena.py +++ /dev/null @@ -1,166 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp -from sqlglot.dialects.trino import Trino -from sqlglot.dialects.hive import Hive -from sqlglot.tokens import TokenType - - -def _generate_as_hive(expression: exp.Expression) -> bool: - if isinstance(expression, exp.Create): - if expression.kind == "TABLE": - properties: t.Optional[exp.Properties] = expression.args.get("properties") - if properties and properties.find(exp.ExternalProperty): - return True # CREATE EXTERNAL TABLE is Hive - - if not isinstance(expression.expression, exp.Query): - return True # any CREATE TABLE other than CREATE TABLE AS SELECT is Hive - else: - return expression.kind != "VIEW" # CREATE VIEW is never Hive but CREATE SCHEMA etc is - - # https://docs.aws.amazon.com/athena/latest/ug/ddl-reference.html - elif isinstance(expression, (exp.Alter, exp.Drop, exp.Describe)): - if isinstance(expression, exp.Drop) and expression.kind == "VIEW": - # DROP VIEW is Trino (I guess because CREATE VIEW is) - return False - - # Everything else is Hive - return True - - return False - - -def _is_iceberg_table(properties: exp.Properties) -> bool: - table_type_property = next( - ( - p - for p in properties.expressions - if isinstance(p, exp.Property) and p.name == "table_type" - ), - None, - ) - return bool(table_type_property and table_type_property.text("value").lower() == "iceberg") - - -def _location_property_sql(self: Athena.Generator, e: exp.LocationProperty): - # If table_type='iceberg', the LocationProperty is called 'location' - # Otherwise, it's called 'external_location' - # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html - - prop_name = "external_location" - - if isinstance(e.parent, exp.Properties): - if _is_iceberg_table(e.parent): - prop_name = "location" - - return f"{prop_name}={self.sql(e, 'this')}" - - -def _partitioned_by_property_sql(self: Athena.Generator, e: exp.PartitionedByProperty) -> str: - # If table_type='iceberg' then the table property for partitioning is called 'partitioning' - # If table_type='hive' it's called 'partitioned_by' - # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties - - prop_name = "partitioned_by" - if isinstance(e.parent, exp.Properties): - if _is_iceberg_table(e.parent): - 
prop_name = "partitioning" - - return f"{prop_name}={self.sql(e, 'this')}" - - -class Athena(Trino): - """ - Over the years, it looks like AWS has taken various execution engines, bolted on AWS-specific modifications and then - built the Athena service around them. - - Thus, Athena is not simply hosted Trino, it's more like a router that routes SQL queries to an execution engine depending - on the query type. - - As at 2024-09-10, assuming your Athena workgroup is configured to use "Athena engine version 3", the following engines exist: - - Hive: - - Accepts mostly the same syntax as Hadoop / Hive - - Uses backticks to quote identifiers - - Has a distinctive DDL syntax (around things like setting table properties, storage locations etc) that is different from Trino - - Used for *most* DDL, with some exceptions that get routed to the Trino engine instead: - - CREATE [EXTERNAL] TABLE (without AS SELECT) - - ALTER - - DROP - - Trino: - - Uses double quotes to quote identifiers - - Used for DDL operations that involve SELECT queries, eg: - - CREATE VIEW / DROP VIEW - - CREATE TABLE... AS SELECT - - Used for DML operations - - SELECT, INSERT, UPDATE, DELETE, MERGE - - The SQLGlot Athena dialect tries to identify which engine a query would be routed to and then uses the parser / generator for that engine - rather than trying to create a universal syntax that can handle both types. - """ - - class Tokenizer(Trino.Tokenizer): - """ - The Tokenizer is flexible enough to tokenize queries across both the Hive and Trino engines - """ - - IDENTIFIERS = ['"', "`"] - KEYWORDS = { - **Hive.Tokenizer.KEYWORDS, - **Trino.Tokenizer.KEYWORDS, - "UNLOAD": TokenType.COMMAND, - } - - class Parser(Trino.Parser): - """ - Parse queries for the Athena Trino execution engine - """ - - STATEMENT_PARSERS = { - **Trino.Parser.STATEMENT_PARSERS, - TokenType.USING: lambda self: self._parse_as_command(self._prev), - } - - class _HiveGenerator(Hive.Generator): - def alter_sql(self, expression: exp.Alter) -> str: - # package any ALTER TABLE ADD actions into a Schema object - # so it gets generated as `ALTER TABLE .. ADD COLUMNS(...)` - # instead of `ALTER TABLE ... 
ADD COLUMN` which is invalid syntax on Athena - if isinstance(expression, exp.Alter) and expression.kind == "TABLE": - if expression.actions and isinstance(expression.actions[0], exp.ColumnDef): - new_actions = exp.Schema(expressions=expression.actions) - expression.set("actions", [new_actions]) - - return super().alter_sql(expression) - - class Generator(Trino.Generator): - """ - Generate queries for the Athena Trino execution engine - """ - - PROPERTIES_LOCATION = { - **Trino.Generator.PROPERTIES_LOCATION, - exp.LocationProperty: exp.Properties.Location.POST_WITH, - } - - TRANSFORMS = { - **Trino.Generator.TRANSFORMS, - exp.PartitionedByProperty: _partitioned_by_property_sql, - exp.LocationProperty: _location_property_sql, - } - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - hive_kwargs = {**kwargs, "dialect": "hive"} - - self._hive_generator = Athena._HiveGenerator(*args, **hive_kwargs) - - def generate(self, expression: exp.Expression, copy: bool = True) -> str: - if _generate_as_hive(expression): - return self._hive_generator.generate(expression, copy) - - return super().generate(expression, copy) diff --git a/altimate_packages/sqlglot/dialects/bigquery.py b/altimate_packages/sqlglot/dialects/bigquery.py deleted file mode 100644 index df5966a84..000000000 --- a/altimate_packages/sqlglot/dialects/bigquery.py +++ /dev/null @@ -1,1331 +0,0 @@ -from __future__ import annotations - -import logging -import re -import typing as t - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot._typing import E -from sqlglot.dialects.dialect import ( - Dialect, - NormalizationStrategy, - annotate_with_type_lambda, - arg_max_or_min_no_count, - binary_from_function, - date_add_interval_sql, - datestrtodate_sql, - build_formatted_time, - filter_array_using_unnest, - if_sql, - inline_array_unless_query, - max_or_greatest, - min_or_least, - no_ilike_sql, - build_date_delta_with_interval, - regexp_replace_sql, - rename_func, - sha256_sql, - timestrtotime_sql, - ts_or_ds_add_cast, - unit_to_var, - strposition_sql, - groupconcat_sql, -) -from sqlglot.helper import seq_get, split_num_words -from sqlglot.tokens import TokenType -from sqlglot.generator import unsupported_args - -if t.TYPE_CHECKING: - from sqlglot._typing import Lit - - from sqlglot.optimizer.annotate_types import TypeAnnotator - -logger = logging.getLogger("sqlglot") - - -JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar, exp.JSONExtractArray] - -DQUOTES_ESCAPING_JSON_FUNCTIONS = ("JSON_QUERY", "JSON_VALUE", "JSON_QUERY_ARRAY") - - -def _derived_table_values_to_unnest(self: BigQuery.Generator, expression: exp.Values) -> str: - if not expression.find_ancestor(exp.From, exp.Join): - return self.values_sql(expression) - - structs = [] - alias = expression.args.get("alias") - for tup in expression.find_all(exp.Tuple): - field_aliases = ( - alias.columns - if alias and alias.columns - else (f"_c{i}" for i in range(len(tup.expressions))) - ) - expressions = [ - exp.PropertyEQ(this=exp.to_identifier(name), expression=fld) - for name, fld in zip(field_aliases, tup.expressions) - ] - structs.append(exp.Struct(expressions=expressions)) - - # Due to `UNNEST_COLUMN_ONLY`, it is expected that the table alias be contained in the columns expression - alias_name_only = exp.TableAlias(columns=[alias.this]) if alias else None - return self.unnest_sql( - exp.Unnest(expressions=[exp.array(*structs, copy=False)], alias=alias_name_only) - ) - - -def _returnsproperty_sql(self: BigQuery.Generator, 
expression: exp.ReturnsProperty) -> str: - this = expression.this - if isinstance(this, exp.Schema): - this = f"{self.sql(this, 'this')} <{self.expressions(this)}>" - else: - this = self.sql(this) - return f"RETURNS {this}" - - -def _create_sql(self: BigQuery.Generator, expression: exp.Create) -> str: - returns = expression.find(exp.ReturnsProperty) - if expression.kind == "FUNCTION" and returns and returns.args.get("is_table"): - expression.set("kind", "TABLE FUNCTION") - - if isinstance(expression.expression, (exp.Subquery, exp.Literal)): - expression.set("expression", expression.expression.this) - - return self.create_sql(expression) - - -# https://issuetracker.google.com/issues/162294746 -# workaround for bigquery bug when grouping by an expression and then ordering -# WITH x AS (SELECT 1 y) -# SELECT y + 1 z -# FROM x -# GROUP BY x + 1 -# ORDER by z -def _alias_ordered_group(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Select): - group = expression.args.get("group") - order = expression.args.get("order") - - if group and order: - aliases = { - select.this: select.args["alias"] - for select in expression.selects - if isinstance(select, exp.Alias) - } - - for grouped in group.expressions: - if grouped.is_int: - continue - alias = aliases.get(grouped) - if alias: - grouped.replace(exp.column(alias)) - - return expression - - -def _pushdown_cte_column_names(expression: exp.Expression) -> exp.Expression: - """BigQuery doesn't allow column names when defining a CTE, so we try to push them down.""" - if isinstance(expression, exp.CTE) and expression.alias_column_names: - cte_query = expression.this - - if cte_query.is_star: - logger.warning( - "Can't push down CTE column names for star queries. Run the query through" - " the optimizer or use 'qualify' to expand the star projections first." 
- ) - return expression - - column_names = expression.alias_column_names - expression.args["alias"].set("columns", None) - - for name, select in zip(column_names, cte_query.selects): - to_replace = select - - if isinstance(select, exp.Alias): - select = select.this - - # Inner aliases are shadowed by the CTE column names - to_replace.replace(exp.alias_(select, name)) - - return expression - - -def _build_parse_timestamp(args: t.List) -> exp.StrToTime: - this = build_formatted_time(exp.StrToTime, "bigquery")([seq_get(args, 1), seq_get(args, 0)]) - this.set("zone", seq_get(args, 2)) - return this - - -def _build_timestamp(args: t.List) -> exp.Timestamp: - timestamp = exp.Timestamp.from_arg_list(args) - timestamp.set("with_tz", True) - return timestamp - - -def _build_date(args: t.List) -> exp.Date | exp.DateFromParts: - expr_type = exp.DateFromParts if len(args) == 3 else exp.Date - return expr_type.from_arg_list(args) - - -def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5: - # TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation - arg = seq_get(args, 0) - return exp.MD5(this=arg.this) if isinstance(arg, exp.MD5Digest) else exp.LowerHex(this=arg) - - -def _array_contains_sql(self: BigQuery.Generator, expression: exp.ArrayContains) -> str: - return self.sql( - exp.Exists( - this=exp.select("1") - .from_(exp.Unnest(expressions=[expression.left]).as_("_unnest", table=["_col"])) - .where(exp.column("_col").eq(expression.right)) - ) - ) - - -def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> str: - return date_add_interval_sql("DATE", "ADD")(self, ts_or_ds_add_cast(expression)) - - -def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str: - expression.this.replace(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)) - expression.expression.replace(exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP)) - unit = unit_to_var(expression) - return self.func("DATE_DIFF", expression.this, expression.expression, unit) - - -def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = expression.this - - if scale in (None, exp.UnixToTime.SECONDS): - return self.func("TIMESTAMP_SECONDS", timestamp) - if scale == exp.UnixToTime.MILLIS: - return self.func("TIMESTAMP_MILLIS", timestamp) - if scale == exp.UnixToTime.MICROS: - return self.func("TIMESTAMP_MICROS", timestamp) - - unix_seconds = exp.cast( - exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), exp.DataType.Type.BIGINT - ) - return self.func("TIMESTAMP_SECONDS", unix_seconds) - - -def _build_time(args: t.List) -> exp.Func: - if len(args) == 1: - return exp.TsOrDsToTime(this=args[0]) - if len(args) == 2: - return exp.Time.from_arg_list(args) - return exp.TimeFromParts.from_arg_list(args) - - -def _build_datetime(args: t.List) -> exp.Func: - if len(args) == 1: - return exp.TsOrDsToDatetime.from_arg_list(args) - if len(args) == 2: - return exp.Datetime.from_arg_list(args) - return exp.TimestampFromParts.from_arg_list(args) - - -def _build_regexp_extract( - expr_type: t.Type[E], default_group: t.Optional[exp.Expression] = None -) -> t.Callable[[t.List], E]: - def _builder(args: t.List) -> E: - try: - group = re.compile(args[1].name).groups == 1 - except re.error: - group = False - - # Default group is used for the transpilation of REGEXP_EXTRACT_ALL - return expr_type( - this=seq_get(args, 0), - expression=seq_get(args, 1), - position=seq_get(args, 2), - 
occurrence=seq_get(args, 3), - group=exp.Literal.number(1) if group else default_group, - ) - - return _builder - - -def _build_extract_json_with_default_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: - def _builder(args: t.List, dialect: Dialect) -> E: - if len(args) == 1: - # The default value for the JSONPath is '$' i.e all of the data - args.append(exp.Literal.string("$")) - return parser.build_extract_json_with_path(expr_type)(args, dialect) - - return _builder - - -def _str_to_datetime_sql( - self: BigQuery.Generator, expression: exp.StrToDate | exp.StrToTime -) -> str: - this = self.sql(expression, "this") - dtype = "DATE" if isinstance(expression, exp.StrToDate) else "TIMESTAMP" - - if expression.args.get("safe"): - fmt = self.format_time( - expression, - self.dialect.INVERSE_FORMAT_MAPPING, - self.dialect.INVERSE_FORMAT_TRIE, - ) - return f"SAFE_CAST({this} AS {dtype} FORMAT {fmt})" - - fmt = self.format_time(expression) - return self.func(f"PARSE_{dtype}", fmt, this, expression.args.get("zone")) - - -def _annotate_math_functions(self: TypeAnnotator, expression: E) -> E: - """ - Many BigQuery math functions such as CEIL, FLOOR etc follow this return type convention: - +---------+---------+---------+------------+---------+ - | INPUT | INT64 | NUMERIC | BIGNUMERIC | FLOAT64 | - +---------+---------+---------+------------+---------+ - | OUTPUT | FLOAT64 | NUMERIC | BIGNUMERIC | FLOAT64 | - +---------+---------+---------+------------+---------+ - """ - self._annotate_args(expression) - - this: exp.Expression = expression.this - - self._set_type( - expression, - exp.DataType.Type.DOUBLE if this.is_type(*exp.DataType.INTEGER_TYPES) else this.type, - ) - return expression - - -@unsupported_args("ins_cost", "del_cost", "sub_cost") -def _levenshtein_sql(self: BigQuery.Generator, expression: exp.Levenshtein) -> str: - max_dist = expression.args.get("max_dist") - if max_dist: - max_dist = exp.Kwarg(this=exp.var("max_distance"), expression=max_dist) - - return self.func("EDIT_DISTANCE", expression.this, expression.expression, max_dist) - - -def _build_levenshtein(args: t.List) -> exp.Levenshtein: - max_dist = seq_get(args, 2) - return exp.Levenshtein( - this=seq_get(args, 0), - expression=seq_get(args, 1), - max_dist=max_dist.expression if max_dist else None, - ) - - -def _build_format_time(expr_type: t.Type[exp.Expression]) -> t.Callable[[t.List], exp.TimeToStr]: - def _builder(args: t.List) -> exp.TimeToStr: - return exp.TimeToStr( - this=expr_type(this=seq_get(args, 1)), - format=seq_get(args, 0), - zone=seq_get(args, 2), - ) - - return _builder - - -def _build_contains_substring(args: t.List) -> exp.Contains | exp.Anonymous: - if len(args) == 3: - return exp.Anonymous(this="CONTAINS_SUBSTR", expressions=args) - - # Lowercase the operands in case of transpilation, as exp.Contains - # is case-sensitive on other dialects - this = exp.Lower(this=seq_get(args, 0)) - expr = exp.Lower(this=seq_get(args, 1)) - - return exp.Contains(this=this, expression=expr) - - -def _json_extract_sql(self: BigQuery.Generator, expression: JSON_EXTRACT_TYPE) -> str: - name = (expression._meta and expression.meta.get("name")) or expression.sql_name() - upper = name.upper() - - dquote_escaping = upper in DQUOTES_ESCAPING_JSON_FUNCTIONS - - if dquote_escaping: - self._quote_json_path_key_using_brackets = False - - sql = rename_func(upper)(self, expression) - - if dquote_escaping: - self._quote_json_path_key_using_brackets = True - - return sql - - -def _annotate_concat(self: TypeAnnotator, 
expression: exp.Concat) -> exp.Concat: - annotated = self._annotate_by_args(expression, "expressions") - - # Args must be BYTES or types that can be cast to STRING, return type is either BYTES or STRING - # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#concat - if not annotated.is_type(exp.DataType.Type.BINARY, exp.DataType.Type.UNKNOWN): - annotated.type = exp.DataType.Type.VARCHAR - - return annotated - - -class BigQuery(Dialect): - WEEK_OFFSET = -1 - UNNEST_COLUMN_ONLY = True - SUPPORTS_USER_DEFINED_TYPES = False - SUPPORTS_SEMI_ANTI_JOIN = False - LOG_BASE_FIRST = False - HEX_LOWERCASE = True - FORCE_EARLY_ALIAS_REF_EXPANSION = True - PRESERVE_ORIGINAL_NAMES = True - HEX_STRING_IS_INTEGER_TYPE = True - - # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE - - # bigquery udfs are case sensitive - NORMALIZE_FUNCTIONS = False - - # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_elements_date_time - TIME_MAPPING = { - "%D": "%m/%d/%y", - "%E6S": "%S.%f", - "%e": "%-d", - } - - FORMAT_MAPPING = { - "DD": "%d", - "MM": "%m", - "MON": "%b", - "MONTH": "%B", - "YYYY": "%Y", - "YY": "%y", - "HH": "%I", - "HH12": "%I", - "HH24": "%H", - "MI": "%M", - "SS": "%S", - "SSSSS": "%f", - "TZH": "%z", - } - - # The _PARTITIONTIME and _PARTITIONDATE pseudo-columns are not returned by a SELECT * statement - # https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table - PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"} - - # All set operations require either a DISTINCT or ALL specifier - SET_OP_DISTINCT_BY_DEFAULT = dict.fromkeys((exp.Except, exp.Intersect, exp.Union), None) - - # BigQuery maps Type.TIMESTAMP to DATETIME, so we need to amend the inferred types - TYPE_TO_EXPRESSIONS = { - **Dialect.TYPE_TO_EXPRESSIONS, - exp.DataType.Type.TIMESTAMPTZ: Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.TIMESTAMP], - } - TYPE_TO_EXPRESSIONS.pop(exp.DataType.Type.TIMESTAMP) - - ANNOTATORS = { - **Dialect.ANNOTATORS, - **{ - expr_type: annotate_with_type_lambda(data_type) - for data_type, expressions in TYPE_TO_EXPRESSIONS.items() - for expr_type in expressions - }, - **{ - expr_type: lambda self, e: _annotate_math_functions(self, e) - for expr_type in (exp.Floor, exp.Ceil, exp.Log, exp.Ln, exp.Sqrt, exp.Exp, exp.Round) - }, - **{ - expr_type: lambda self, e: self._annotate_by_args(e, "this") - for expr_type in ( - exp.Left, - exp.Right, - exp.Lower, - exp.Upper, - exp.Pad, - exp.Trim, - exp.RegexpExtract, - exp.RegexpReplace, - exp.Repeat, - exp.Substring, - ) - }, - exp.Concat: _annotate_concat, - exp.Sign: lambda self, e: self._annotate_by_args(e, "this"), - exp.Split: lambda self, e: self._annotate_by_args(e, "this", array=True), - } - - def normalize_identifier(self, expression: E) -> E: - if ( - isinstance(expression, exp.Identifier) - and self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE - ): - parent = expression.parent - while isinstance(parent, exp.Dot): - parent = parent.parent - - # In BigQuery, CTEs are case-insensitive, but UDF and table names are case-sensitive - # by default. The following check uses a heuristic to detect tables based on whether - # they are qualified. This should generally be correct, because tables in BigQuery - # must be qualified with at least a dataset, unless @@dataset_id is set. 
- case_sensitive = ( - isinstance(parent, exp.UserDefinedFunction) - or ( - isinstance(parent, exp.Table) - and parent.db - and (parent.meta.get("quoted_table") or not parent.meta.get("maybe_column")) - ) - or expression.meta.get("is_table") - ) - if not case_sensitive: - expression.set("this", expression.this.lower()) - - return t.cast(E, expression) - - return super().normalize_identifier(expression) - - class Tokenizer(tokens.Tokenizer): - QUOTES = ["'", '"', '"""', "'''"] - COMMENTS = ["--", "#", ("/*", "*/")] - IDENTIFIERS = ["`"] - STRING_ESCAPES = ["\\"] - - HEX_STRINGS = [("0x", ""), ("0X", "")] - - BYTE_STRINGS = [ - (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("b", "B") - ] - - RAW_STRINGS = [ - (prefix + q, q) for q in t.cast(t.List[str], QUOTES) for prefix in ("r", "R") - ] - - NESTED_COMMENTS = False - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "ANY TYPE": TokenType.VARIANT, - "BEGIN": TokenType.COMMAND, - "BEGIN TRANSACTION": TokenType.BEGIN, - "BYTEINT": TokenType.INT, - "BYTES": TokenType.BINARY, - "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, - "DATETIME": TokenType.TIMESTAMP, - "DECLARE": TokenType.COMMAND, - "ELSEIF": TokenType.COMMAND, - "EXCEPTION": TokenType.COMMAND, - "EXPORT": TokenType.EXPORT, - "FLOAT64": TokenType.DOUBLE, - "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, - "MODEL": TokenType.MODEL, - "NOT DETERMINISTIC": TokenType.VOLATILE, - "RECORD": TokenType.STRUCT, - "TIMESTAMP": TokenType.TIMESTAMPTZ, - } - KEYWORDS.pop("DIV") - KEYWORDS.pop("VALUES") - KEYWORDS.pop("/*+") - - class Parser(parser.Parser): - PREFIXED_PIVOT_COLUMNS = True - LOG_DEFAULTS_TO_LN = True - SUPPORTS_IMPLICIT_UNNEST = True - - # BigQuery does not allow ASC/DESC to be used as an identifier - ID_VAR_TOKENS = parser.Parser.ID_VAR_TOKENS - {TokenType.ASC, TokenType.DESC} - ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - {TokenType.ASC, TokenType.DESC} - TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {TokenType.ASC, TokenType.DESC} - COMMENT_TABLE_ALIAS_TOKENS = parser.Parser.COMMENT_TABLE_ALIAS_TOKENS - { - TokenType.ASC, - TokenType.DESC, - } - UPDATE_ALIAS_TOKENS = parser.Parser.UPDATE_ALIAS_TOKENS - {TokenType.ASC, TokenType.DESC} - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "CONTAINS_SUBSTR": _build_contains_substring, - "DATE": _build_date, - "DATE_ADD": build_date_delta_with_interval(exp.DateAdd), - "DATE_SUB": build_date_delta_with_interval(exp.DateSub), - "DATE_TRUNC": lambda args: exp.DateTrunc( - unit=exp.Literal.string(str(seq_get(args, 1))), - this=seq_get(args, 0), - zone=seq_get(args, 2), - ), - "DATETIME": _build_datetime, - "DATETIME_ADD": build_date_delta_with_interval(exp.DatetimeAdd), - "DATETIME_SUB": build_date_delta_with_interval(exp.DatetimeSub), - "DIV": binary_from_function(exp.IntDiv), - "EDIT_DISTANCE": _build_levenshtein, - "FORMAT_DATE": _build_format_time(exp.TsOrDsToDate), - "GENERATE_ARRAY": exp.GenerateSeries.from_arg_list, - "JSON_EXTRACT_SCALAR": _build_extract_json_with_default_path(exp.JSONExtractScalar), - "JSON_EXTRACT_ARRAY": _build_extract_json_with_default_path(exp.JSONExtractArray), - "JSON_QUERY": parser.build_extract_json_with_path(exp.JSONExtract), - "JSON_QUERY_ARRAY": _build_extract_json_with_default_path(exp.JSONExtractArray), - "JSON_VALUE": _build_extract_json_with_default_path(exp.JSONExtractScalar), - "JSON_VALUE_ARRAY": _build_extract_json_with_default_path(exp.JSONValueArray), - "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True), - "MD5": exp.MD5Digest.from_arg_list, - 
"TO_HEX": _build_to_hex, - "PARSE_DATE": lambda args: build_formatted_time(exp.StrToDate, "bigquery")( - [seq_get(args, 1), seq_get(args, 0)] - ), - "PARSE_TIMESTAMP": _build_parse_timestamp, - "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list, - "REGEXP_EXTRACT": _build_regexp_extract(exp.RegexpExtract), - "REGEXP_SUBSTR": _build_regexp_extract(exp.RegexpExtract), - "REGEXP_EXTRACT_ALL": _build_regexp_extract( - exp.RegexpExtractAll, default_group=exp.Literal.number(0) - ), - "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), - "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)), - "SPLIT": lambda args: exp.Split( - # https://cloud.google.com/bigquery/docs/reference/standard-sql/string_functions#split - this=seq_get(args, 0), - expression=seq_get(args, 1) or exp.Literal.string(","), - ), - "STRPOS": exp.StrPosition.from_arg_list, - "TIME": _build_time, - "TIME_ADD": build_date_delta_with_interval(exp.TimeAdd), - "TIME_SUB": build_date_delta_with_interval(exp.TimeSub), - "TIMESTAMP": _build_timestamp, - "TIMESTAMP_ADD": build_date_delta_with_interval(exp.TimestampAdd), - "TIMESTAMP_SUB": build_date_delta_with_interval(exp.TimestampSub), - "TIMESTAMP_MICROS": lambda args: exp.UnixToTime( - this=seq_get(args, 0), scale=exp.UnixToTime.MICROS - ), - "TIMESTAMP_MILLIS": lambda args: exp.UnixToTime( - this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS - ), - "TIMESTAMP_SECONDS": lambda args: exp.UnixToTime(this=seq_get(args, 0)), - "TO_JSON_STRING": exp.JSONFormat.from_arg_list, - "FORMAT_DATETIME": _build_format_time(exp.TsOrDsToDatetime), - "FORMAT_TIMESTAMP": _build_format_time(exp.TsOrDsToTimestamp), - } - - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, - "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]), - "MAKE_INTERVAL": lambda self: self._parse_make_interval(), - "FEATURES_AT_TIME": lambda self: self._parse_features_at_time(), - } - FUNCTION_PARSERS.pop("TRIM") - - NO_PAREN_FUNCTIONS = { - **parser.Parser.NO_PAREN_FUNCTIONS, - TokenType.CURRENT_DATETIME: exp.CurrentDatetime, - } - - NESTED_TYPE_TOKENS = { - *parser.Parser.NESTED_TYPE_TOKENS, - TokenType.TABLE, - } - - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, - "NOT DETERMINISTIC": lambda self: self.expression( - exp.StabilityProperty, this=exp.Literal.string("VOLATILE") - ), - "OPTIONS": lambda self: self._parse_with_property(), - } - - CONSTRAINT_PARSERS = { - **parser.Parser.CONSTRAINT_PARSERS, - "OPTIONS": lambda self: exp.Properties(expressions=self._parse_with_property()), - } - - RANGE_PARSERS = parser.Parser.RANGE_PARSERS.copy() - RANGE_PARSERS.pop(TokenType.OVERLAPS) - - NULL_TOKENS = {TokenType.NULL, TokenType.UNKNOWN} - - DASHED_TABLE_PART_FOLLOW_TOKENS = {TokenType.DOT, TokenType.L_PAREN, TokenType.R_PAREN} - - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, - TokenType.ELSE: lambda self: self._parse_as_command(self._prev), - TokenType.END: lambda self: self._parse_as_command(self._prev), - TokenType.FOR: lambda self: self._parse_for_in(), - TokenType.EXPORT: lambda self: self._parse_export_data(), - } - - BRACKET_OFFSETS = { - "OFFSET": (0, False), - "ORDINAL": (1, False), - "SAFE_OFFSET": (0, True), - "SAFE_ORDINAL": (1, True), - } - - def _parse_for_in(self) -> exp.ForIn: - this = self._parse_range() - self._match_text_seq("DO") - return self.expression(exp.ForIn, this=this, expression=self._parse_statement()) - - def _parse_table_part(self, schema: bool = False) -> 
t.Optional[exp.Expression]: - this = super()._parse_table_part(schema=schema) or self._parse_number() - - # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#table_names - if isinstance(this, exp.Identifier): - table_name = this.name - while self._match(TokenType.DASH, advance=False) and self._next: - start = self._curr - while self._is_connected() and not self._match_set( - self.DASHED_TABLE_PART_FOLLOW_TOKENS, advance=False - ): - self._advance() - - if start == self._curr: - break - - table_name += self._find_sql(start, self._prev) - - this = exp.Identifier( - this=table_name, quoted=this.args.get("quoted") - ).update_positions(this) - elif isinstance(this, exp.Literal): - table_name = this.name - - if self._is_connected() and self._parse_var(any_token=True): - table_name += self._prev.text - - this = exp.Identifier(this=table_name, quoted=True).update_positions(this) - - return this - - def _parse_table_parts( - self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False - ) -> exp.Table: - table = super()._parse_table_parts( - schema=schema, is_db_reference=is_db_reference, wildcard=True - ) - - # proj-1.db.tbl -- `1.` is tokenized as a float so we need to unravel it here - if not table.catalog: - if table.db: - previous_db = table.args["db"] - parts = table.db.split(".") - if len(parts) == 2 and not table.args["db"].quoted: - table.set( - "catalog", exp.Identifier(this=parts[0]).update_positions(previous_db) - ) - table.set("db", exp.Identifier(this=parts[1]).update_positions(previous_db)) - else: - previous_this = table.this - parts = table.name.split(".") - if len(parts) == 2 and not table.this.quoted: - table.set( - "db", exp.Identifier(this=parts[0]).update_positions(previous_this) - ) - table.set( - "this", exp.Identifier(this=parts[1]).update_positions(previous_this) - ) - - if isinstance(table.this, exp.Identifier) and any("." in p.name for p in table.parts): - alias = table.this - catalog, db, this, *rest = ( - exp.to_identifier(p, quoted=True) - for p in split_num_words(".".join(p.name for p in table.parts), ".", 3) - ) - - for part in (catalog, db, this): - if part: - part.update_positions(table.this) - - if rest and this: - this = exp.Dot.build([this, *rest]) # type: ignore - - table = exp.Table( - this=this, db=db, catalog=catalog, pivots=table.args.get("pivots") - ) - table.meta["quoted_table"] = True - else: - alias = None - - # The `INFORMATION_SCHEMA` views in BigQuery need to be qualified by a region or - # dataset, so if the project identifier is omitted we need to fix the ast so that - # the `INFORMATION_SCHEMA.X` bit is represented as a single (quoted) Identifier. - # Otherwise, we wouldn't correctly qualify a `Table` node that references these - # views, because it would seem like the "catalog" part is set, when it'd actually - # be the region/dataset. Merging the two identifiers into a single one is done to - # avoid producing a 4-part Table reference, which would cause issues in the schema - # module, when there are 3-part table names mixed with information schema views. - # - # See: https://cloud.google.com/bigquery/docs/information-schema-intro#syntax - table_parts = table.parts - if len(table_parts) > 1 and table_parts[-2].name.upper() == "INFORMATION_SCHEMA": - # We need to alias the table here to avoid breaking existing qualified columns. - # This is expected to be safe, because if there's an actual alias coming up in - # the token stream, it will overwrite this one. 
If there isn't one, we are only - # exposing the name that can be used to reference the view explicitly (a no-op). - exp.alias_( - table, - t.cast(exp.Identifier, alias or table_parts[-1]), - table=True, - copy=False, - ) - - info_schema_view = f"{table_parts[-2].name}.{table_parts[-1].name}" - new_this = exp.Identifier(this=info_schema_view, quoted=True).update_positions( - line=table_parts[-2].meta.get("line"), - col=table_parts[-1].meta.get("col"), - start=table_parts[-2].meta.get("start"), - end=table_parts[-1].meta.get("end"), - ) - table.set("this", new_this) - table.set("db", seq_get(table_parts, -3)) - table.set("catalog", seq_get(table_parts, -4)) - - return table - - def _parse_column(self) -> t.Optional[exp.Expression]: - column = super()._parse_column() - if isinstance(column, exp.Column): - parts = column.parts - if any("." in p.name for p in parts): - catalog, db, table, this, *rest = ( - exp.to_identifier(p, quoted=True) - for p in split_num_words(".".join(p.name for p in parts), ".", 4) - ) - - if rest and this: - this = exp.Dot.build([this, *rest]) # type: ignore - - column = exp.Column(this=this, table=table, db=db, catalog=catalog) - column.meta["quoted_column"] = True - - return column - - @t.overload - def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... - - @t.overload - def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... - - def _parse_json_object(self, agg=False): - json_object = super()._parse_json_object() - array_kv_pair = seq_get(json_object.expressions, 0) - - # Converts BQ's "signature 2" of JSON_OBJECT into SQLGlot's canonical representation - # https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#json_object_signature2 - if ( - array_kv_pair - and isinstance(array_kv_pair.this, exp.Array) - and isinstance(array_kv_pair.expression, exp.Array) - ): - keys = array_kv_pair.this.expressions - values = array_kv_pair.expression.expressions - - json_object.set( - "expressions", - [exp.JSONKeyValue(this=k, expression=v) for k, v in zip(keys, values)], - ) - - return json_object - - def _parse_bracket( - self, this: t.Optional[exp.Expression] = None - ) -> t.Optional[exp.Expression]: - bracket = super()._parse_bracket(this) - - if this is bracket: - return bracket - - if isinstance(bracket, exp.Bracket): - for expression in bracket.expressions: - name = expression.name.upper() - - if name not in self.BRACKET_OFFSETS: - break - - offset, safe = self.BRACKET_OFFSETS[name] - bracket.set("offset", offset) - bracket.set("safe", safe) - expression.replace(expression.expressions[0]) - - return bracket - - def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: - unnest = super()._parse_unnest(with_alias=with_alias) - - if not unnest: - return None - - unnest_expr = seq_get(unnest.expressions, 0) - if unnest_expr: - from sqlglot.optimizer.annotate_types import annotate_types - - unnest_expr = annotate_types(unnest_expr, dialect=self.dialect) - - # Unnesting a nested array (i.e array of structs) explodes the top-level struct fields, - # in contrast to other dialects such as DuckDB which flattens only the array by default - if unnest_expr.is_type(exp.DataType.Type.ARRAY) and any( - array_elem.is_type(exp.DataType.Type.STRUCT) - for array_elem in unnest_expr._type.expressions - ): - unnest.set("explode_array", True) - - return unnest - - def _parse_make_interval(self) -> exp.MakeInterval: - expr = exp.MakeInterval() - - for arg_key in expr.arg_types: - value = self._parse_lambda() - - if not value: - 
break - - # Non-named arguments are filled sequentially, (optionally) followed by named arguments - # that can appear in any order e.g MAKE_INTERVAL(1, minute => 5, day => 2) - if isinstance(value, exp.Kwarg): - arg_key = value.this.name - - expr.set(arg_key, value) - - self._match(TokenType.COMMA) - - return expr - - def _parse_features_at_time(self) -> exp.FeaturesAtTime: - expr = self.expression( - exp.FeaturesAtTime, - this=(self._match(TokenType.TABLE) and self._parse_table()) - or self._parse_select(nested=True), - ) - - while self._match(TokenType.COMMA): - arg = self._parse_lambda() - - # Get the LHS of the Kwarg and set the arg to that value, e.g - # "num_rows => 1" sets the expr's `num_rows` arg - if arg: - expr.set(arg.this.name, arg) - - return expr - - def _parse_export_data(self) -> exp.Export: - self._match_text_seq("DATA") - - return self.expression( - exp.Export, - connection=self._match_text_seq("WITH", "CONNECTION") and self._parse_table_parts(), - options=self._parse_properties(), - this=self._match_text_seq("AS") and self._parse_select(), - ) - - class Generator(generator.Generator): - INTERVAL_ALLOWS_PLURAL_FORM = False - JOIN_HINTS = False - QUERY_HINTS = False - TABLE_HINTS = False - LIMIT_FETCH = "LIMIT" - RENAME_TABLE_WITH_DB = False - NVL2_SUPPORTED = False - UNNEST_WITH_ORDINALITY = False - COLLATE_IS_FUNC = True - LIMIT_ONLY_LITERALS = True - SUPPORTS_TABLE_ALIAS_COLUMNS = False - UNPIVOT_ALIASES_ARE_IDENTIFIERS = False - JSON_KEY_VALUE_PAIR_SEP = "," - NULL_ORDERING_SUPPORTED = False - IGNORE_NULLS_IN_FUNC = True - JSON_PATH_SINGLE_QUOTE_ESCAPE = True - CAN_IMPLEMENT_ARRAY_ANY = True - SUPPORTS_TO_NUMBER = False - NAMED_PLACEHOLDER_TOKEN = "@" - HEX_FUNC = "TO_HEX" - WITH_PROPERTIES_PREFIX = "OPTIONS" - SUPPORTS_EXPLODING_PROJECTIONS = False - EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False - SUPPORTS_UNIX_SECONDS = True - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), - exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), - exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), - exp.Array: inline_array_unless_query, - exp.ArrayContains: _array_contains_sql, - exp.ArrayFilter: filter_array_using_unnest, - exp.ArrayRemove: filter_array_using_unnest, - exp.Cast: transforms.preprocess([transforms.remove_precision_parameterized_types]), - exp.CollateProperty: lambda self, e: ( - f"DEFAULT COLLATE {self.sql(e, 'this')}" - if e.args.get("default") - else f"COLLATE {self.sql(e, 'this')}" - ), - exp.Commit: lambda *_: "COMMIT TRANSACTION", - exp.CountIf: rename_func("COUNTIF"), - exp.Create: _create_sql, - exp.CTE: transforms.preprocess([_pushdown_cte_column_names]), - exp.DateAdd: date_add_interval_sql("DATE", "ADD"), - exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", e.this, e.expression, unit_to_var(e) - ), - exp.DateFromParts: rename_func("DATE"), - exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: date_add_interval_sql("DATE", "SUB"), - exp.DatetimeAdd: date_add_interval_sql("DATETIME", "ADD"), - exp.DatetimeSub: date_add_interval_sql("DATETIME", "SUB"), - exp.DateTrunc: lambda self, e: self.func( - "DATE_TRUNC", e.this, e.text("unit"), e.args.get("zone") - ), - exp.FromTimeZone: lambda self, e: self.func( - "DATETIME", self.func("TIMESTAMP", e.this, e.args.get("zone")), "'UTC'" - ), - exp.GenerateSeries: rename_func("GENERATE_ARRAY"), - exp.GroupConcat: lambda self, e: groupconcat_sql( - self, e, func_name="STRING_AGG", within_group=False - ), - exp.Hex: lambda self, e: self.func("UPPER", 
self.func("TO_HEX", self.sql(e, "this"))), - exp.HexString: lambda self, e: self.hexstring_sql(e, binary_function_repr="FROM_HEX"), - exp.If: if_sql(false_value="NULL"), - exp.ILike: no_ilike_sql, - exp.IntDiv: rename_func("DIV"), - exp.Int64: rename_func("INT64"), - exp.JSONExtract: _json_extract_sql, - exp.JSONExtractArray: _json_extract_sql, - exp.JSONExtractScalar: _json_extract_sql, - exp.JSONFormat: rename_func("TO_JSON_STRING"), - exp.Levenshtein: _levenshtein_sql, - exp.Max: max_or_greatest, - exp.MD5: lambda self, e: self.func("TO_HEX", self.func("MD5", e.this)), - exp.MD5Digest: rename_func("MD5"), - exp.Min: min_or_least, - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", - exp.RegexpExtract: lambda self, e: self.func( - "REGEXP_EXTRACT", - e.this, - e.expression, - e.args.get("position"), - e.args.get("occurrence"), - ), - exp.RegexpExtractAll: lambda self, e: self.func( - "REGEXP_EXTRACT_ALL", e.this, e.expression - ), - exp.RegexpReplace: regexp_replace_sql, - exp.RegexpLike: rename_func("REGEXP_CONTAINS"), - exp.ReturnsProperty: _returnsproperty_sql, - exp.Rollback: lambda *_: "ROLLBACK TRANSACTION", - exp.Select: transforms.preprocess( - [ - transforms.explode_projection_to_unnest(), - transforms.unqualify_unnest, - transforms.eliminate_distinct_on, - _alias_ordered_group, - transforms.eliminate_semi_and_anti_joins, - ] - ), - exp.SHA: rename_func("SHA1"), - exp.SHA2: sha256_sql, - exp.StabilityProperty: lambda self, e: ( - "DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC" - ), - exp.String: rename_func("STRING"), - exp.StrPosition: lambda self, e: ( - strposition_sql( - self, e, func_name="INSTR", supports_position=True, supports_occurrence=True - ) - ), - exp.StrToDate: _str_to_datetime_sql, - exp.StrToTime: _str_to_datetime_sql, - exp.TimeAdd: date_add_interval_sql("TIME", "ADD"), - exp.TimeFromParts: rename_func("TIME"), - exp.TimestampFromParts: rename_func("DATETIME"), - exp.TimeSub: date_add_interval_sql("TIME", "SUB"), - exp.TimestampAdd: date_add_interval_sql("TIMESTAMP", "ADD"), - exp.TimestampDiff: rename_func("TIMESTAMP_DIFF"), - exp.TimestampSub: date_add_interval_sql("TIMESTAMP", "SUB"), - exp.TimeStrToTime: timestrtotime_sql, - exp.Transaction: lambda *_: "BEGIN TRANSACTION", - exp.TsOrDsAdd: _ts_or_ds_add_sql, - exp.TsOrDsDiff: _ts_or_ds_diff_sql, - exp.TsOrDsToTime: rename_func("TIME"), - exp.TsOrDsToDatetime: rename_func("DATETIME"), - exp.TsOrDsToTimestamp: rename_func("TIMESTAMP"), - exp.Unhex: rename_func("FROM_HEX"), - exp.UnixDate: rename_func("UNIX_DATE"), - exp.UnixToTime: _unix_to_time_sql, - exp.Uuid: lambda *_: "GENERATE_UUID()", - exp.Values: _derived_table_values_to_unnest, - exp.VariancePop: rename_func("VAR_POP"), - exp.SafeDivide: rename_func("SAFE_DIVIDE"), - } - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC", - exp.DataType.Type.BIGINT: "INT64", - exp.DataType.Type.BINARY: "BYTES", - exp.DataType.Type.BLOB: "BYTES", - exp.DataType.Type.BOOLEAN: "BOOL", - exp.DataType.Type.CHAR: "STRING", - exp.DataType.Type.DECIMAL: "NUMERIC", - exp.DataType.Type.DOUBLE: "FLOAT64", - exp.DataType.Type.FLOAT: "FLOAT64", - exp.DataType.Type.INT: "INT64", - exp.DataType.Type.NCHAR: "STRING", - exp.DataType.Type.NVARCHAR: "STRING", - exp.DataType.Type.SMALLINT: "INT64", - exp.DataType.Type.TEXT: "STRING", - exp.DataType.Type.TIMESTAMP: "DATETIME", - 
exp.DataType.Type.TIMESTAMPNTZ: "DATETIME", - exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", - exp.DataType.Type.TINYINT: "INT64", - exp.DataType.Type.ROWVERSION: "BYTES", - exp.DataType.Type.UUID: "STRING", - exp.DataType.Type.VARBINARY: "BYTES", - exp.DataType.Type.VARCHAR: "STRING", - exp.DataType.Type.VARIANT: "ANY TYPE", - } - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - # WINDOW comes after QUALIFY - # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#window_clause - AFTER_HAVING_MODIFIER_TRANSFORMS = { - "qualify": generator.Generator.AFTER_HAVING_MODIFIER_TRANSFORMS["qualify"], - "windows": generator.Generator.AFTER_HAVING_MODIFIER_TRANSFORMS["windows"], - } - - # from: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#reserved_keywords - RESERVED_KEYWORDS = { - "all", - "and", - "any", - "array", - "as", - "asc", - "assert_rows_modified", - "at", - "between", - "by", - "case", - "cast", - "collate", - "contains", - "create", - "cross", - "cube", - "current", - "default", - "define", - "desc", - "distinct", - "else", - "end", - "enum", - "escape", - "except", - "exclude", - "exists", - "extract", - "false", - "fetch", - "following", - "for", - "from", - "full", - "group", - "grouping", - "groups", - "hash", - "having", - "if", - "ignore", - "in", - "inner", - "intersect", - "interval", - "into", - "is", - "join", - "lateral", - "left", - "like", - "limit", - "lookup", - "merge", - "natural", - "new", - "no", - "not", - "null", - "nulls", - "of", - "on", - "or", - "order", - "outer", - "over", - "partition", - "preceding", - "proto", - "qualify", - "range", - "recursive", - "respect", - "right", - "rollup", - "rows", - "select", - "set", - "some", - "struct", - "tablesample", - "then", - "to", - "treat", - "true", - "unbounded", - "union", - "unnest", - "using", - "when", - "where", - "window", - "with", - "within", - } - - def mod_sql(self, expression: exp.Mod) -> str: - this = expression.this - expr = expression.expression - return self.func( - "MOD", - this.unnest() if isinstance(this, exp.Paren) else this, - expr.unnest() if isinstance(expr, exp.Paren) else expr, - ) - - def column_parts(self, expression: exp.Column) -> str: - if expression.meta.get("quoted_column"): - # If a column reference is of the form `dataset.table`.name, we need - # to preserve the quoted table path, otherwise the reference breaks - table_parts = ".".join(p.name for p in expression.parts[:-1]) - table_path = self.sql(exp.Identifier(this=table_parts, quoted=True)) - return f"{table_path}.{self.sql(expression, 'this')}" - - return super().column_parts(expression) - - def table_parts(self, expression: exp.Table) -> str: - # Depending on the context, `x.y` may not resolve to the same data source as `x`.`y`, so - # we need to make sure the correct quoting is used in each case. 
- # - # For example, if there is a CTE x that clashes with a schema name, then the former will - # return the table y in that schema, whereas the latter will return the CTE's y column: - # - # - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x.y` -> cross join - # - WITH x AS (SELECT [1, 2] AS y) SELECT * FROM x, `x`.`y` -> implicit unnest - if expression.meta.get("quoted_table"): - table_parts = ".".join(p.name for p in expression.parts) - return self.sql(exp.Identifier(this=table_parts, quoted=True)) - - return super().table_parts(expression) - - def timetostr_sql(self, expression: exp.TimeToStr) -> str: - this = expression.this - if isinstance(this, exp.TsOrDsToDatetime): - func_name = "FORMAT_DATETIME" - elif isinstance(this, exp.TsOrDsToTimestamp): - func_name = "FORMAT_TIMESTAMP" - else: - func_name = "FORMAT_DATE" - - time_expr = ( - this - if isinstance(this, (exp.TsOrDsToDatetime, exp.TsOrDsToTimestamp, exp.TsOrDsToDate)) - else expression - ) - return self.func( - func_name, self.format_time(expression), time_expr.this, expression.args.get("zone") - ) - - def eq_sql(self, expression: exp.EQ) -> str: - # Operands of = cannot be NULL in BigQuery - if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null): - if not isinstance(expression.parent, exp.Update): - return "NULL" - - return self.binary(expression, "=") - - def attimezone_sql(self, expression: exp.AtTimeZone) -> str: - parent = expression.parent - - # BigQuery allows CAST(.. AS {STRING|TIMESTAMP} [FORMAT [AT TIME ZONE ]]). - # Only the TIMESTAMP one should use the below conversion, when AT TIME ZONE is included. - if not isinstance(parent, exp.Cast) or not parent.to.is_type("text"): - return self.func( - "TIMESTAMP", self.func("DATETIME", expression.this, expression.args.get("zone")) - ) - - return super().attimezone_sql(expression) - - def trycast_sql(self, expression: exp.TryCast) -> str: - return self.cast_sql(expression, safe_prefix="SAFE_") - - def bracket_sql(self, expression: exp.Bracket) -> str: - this = expression.this - expressions = expression.expressions - - if len(expressions) == 1 and this and this.is_type(exp.DataType.Type.STRUCT): - arg = expressions[0] - if arg.type is None: - from sqlglot.optimizer.annotate_types import annotate_types - - arg = annotate_types(arg, dialect=self.dialect) - - if arg.type and arg.type.this in exp.DataType.TEXT_TYPES: - # BQ doesn't support bracket syntax with string values for structs - return f"{self.sql(this)}.{arg.name}" - - expressions_sql = self.expressions(expression, flat=True) - offset = expression.args.get("offset") - - if offset == 0: - expressions_sql = f"OFFSET({expressions_sql})" - elif offset == 1: - expressions_sql = f"ORDINAL({expressions_sql})" - elif offset is not None: - self.unsupported(f"Unsupported array offset: {offset}") - - if expression.args.get("safe"): - expressions_sql = f"SAFE_{expressions_sql}" - - return f"{self.sql(this)}[{expressions_sql}]" - - def in_unnest_op(self, expression: exp.Unnest) -> str: - return self.sql(expression) - - def version_sql(self, expression: exp.Version) -> str: - if expression.name == "TIMESTAMP": - expression.set("this", "SYSTEM_TIME") - return super().version_sql(expression) - - def contains_sql(self, expression: exp.Contains) -> str: - this = expression.this - expr = expression.expression - - if isinstance(this, exp.Lower) and isinstance(expr, exp.Lower): - this = this.this - expr = expr.this - - return self.func("CONTAINS_SUBSTR", this, expr) - - def cast_sql(self, expression: exp.Cast, 
safe_prefix: t.Optional[str] = None) -> str:
-        this = expression.this
-
-        # This ensures that inline type-annotated ARRAY literals like ARRAY[1, 2, 3]
-        # are roundtripped unaffected. The inner check excludes ARRAY(SELECT ...) expressions,
-        # because they aren't literals and so the above syntax is invalid BigQuery.
-        if isinstance(this, exp.Array):
-            elem = seq_get(this.expressions, 0)
-            if not (elem and elem.find(exp.Query)):
-                return f"{self.sql(expression, 'to')}{self.sql(this)}"
-
-        return super().cast_sql(expression, safe_prefix=safe_prefix)
diff --git a/altimate_packages/sqlglot/dialects/clickhouse.py b/altimate_packages/sqlglot/dialects/clickhouse.py
deleted file mode 100644
index c5d3cdf09..000000000
--- a/altimate_packages/sqlglot/dialects/clickhouse.py
+++ /dev/null
@@ -1,1393 +0,0 @@
-from __future__ import annotations
-import typing as t
-import datetime
-from sqlglot import exp, generator, parser, tokens
-from sqlglot.dialects.dialect import (
-    Dialect,
-    NormalizationStrategy,
-    arg_max_or_min_no_count,
-    build_date_delta,
-    build_formatted_time,
-    inline_array_sql,
-    json_extract_segments,
-    json_path_key_only_name,
-    length_or_char_length_sql,
-    no_pivot_sql,
-    build_json_extract_path,
-    rename_func,
-    remove_from_array_using_filter,
-    sha256_sql,
-    strposition_sql,
-    var_map_sql,
-    timestamptrunc_sql,
-    unit_to_var,
-    trim_sql,
-)
-from sqlglot.generator import Generator
-from sqlglot.helper import is_int, seq_get
-from sqlglot.tokens import Token, TokenType
-from sqlglot.generator import unsupported_args
-
-DATEΤΙΜΕ_DELTA = t.Union[exp.DateAdd, exp.DateDiff, exp.DateSub, exp.TimestampSub, exp.TimestampAdd]
-
-
-def _build_date_format(args: t.List) -> exp.TimeToStr:
-    expr = build_formatted_time(exp.TimeToStr, "clickhouse")(args)
-
-    timezone = seq_get(args, 2)
-    if timezone:
-        expr.set("zone", timezone)
-
-    return expr
-
-
-def _unix_to_time_sql(self: ClickHouse.Generator, expression: exp.UnixToTime) -> str:
-    scale = expression.args.get("scale")
-    timestamp = expression.this
-
-    if scale in (None, exp.UnixToTime.SECONDS):
-        return self.func("fromUnixTimestamp", exp.cast(timestamp, exp.DataType.Type.BIGINT))
-    if scale == exp.UnixToTime.MILLIS:
-        return self.func("fromUnixTimestamp64Milli", exp.cast(timestamp, exp.DataType.Type.BIGINT))
-    if scale == exp.UnixToTime.MICROS:
-        return self.func("fromUnixTimestamp64Micro", exp.cast(timestamp, exp.DataType.Type.BIGINT))
-    if scale == exp.UnixToTime.NANOS:
-        return self.func("fromUnixTimestamp64Nano", exp.cast(timestamp, exp.DataType.Type.BIGINT))
-
-    return self.func(
-        "fromUnixTimestamp",
-        exp.cast(
-            exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), exp.DataType.Type.BIGINT
-        ),
-    )
-
-
-def _lower_func(sql: str) -> str:
-    index = sql.index("(")
-    return sql[:index].lower() + sql[index:]
-
-
-def _quantile_sql(self: ClickHouse.Generator, expression: exp.Quantile) -> str:
-    quantile = expression.args["quantile"]
-    args = f"({self.sql(expression, 'this')})"
-
-    if isinstance(quantile, exp.Array):
-        func = self.func("quantiles", *quantile)
-    else:
-        func = self.func("quantile", quantile)
-
-    return func + args
-
-
-def _build_count_if(args: t.List) -> exp.CountIf | exp.CombinedAggFunc:
-    if len(args) == 1:
-        return exp.CountIf(this=seq_get(args, 0))
-
-    return exp.CombinedAggFunc(this="countIf", expressions=args)
-
-
-def _build_str_to_date(args: t.List) -> exp.Cast | exp.Anonymous:
-    if len(args) == 3:
-        return exp.Anonymous(this="STR_TO_DATE", expressions=args)
-
-    strtodate = 
exp.StrToDate.from_arg_list(args)
-    return exp.cast(strtodate, exp.DataType.build(exp.DataType.Type.DATETIME))
-
-
-def _datetime_delta_sql(name: str) -> t.Callable[[Generator, DATEΤΙΜΕ_DELTA], str]:
-    def _delta_sql(self: Generator, expression: DATEΤΙΜΕ_DELTA) -> str:
-        if not expression.unit:
-            return rename_func(name)(self, expression)
-
-        return self.func(
-            name,
-            unit_to_var(expression),
-            expression.expression,
-            expression.this,
-            expression.args.get("zone"),
-        )
-
-    return _delta_sql
-
-
-def _timestrtotime_sql(self: ClickHouse.Generator, expression: exp.TimeStrToTime):
-    ts = expression.this
-
-    tz = expression.args.get("zone")
-    if tz and isinstance(ts, exp.Literal):
-        # Clickhouse will not accept timestamps that include a UTC offset, so we must remove them.
-        # The first step to removing is parsing the string with `datetime.datetime.fromisoformat`.
-        #
-        # In python <3.11, `fromisoformat()` can only parse timestamps of millisecond (3 digit)
-        # or microsecond (6 digit) precision. It will error if passed any other number of fractional
-        # digits, so we extract the fractional seconds and pad to 6 digits before parsing.
-        ts_string = ts.name.strip()
-
-        # separate [date and time] from [fractional seconds and UTC offset]
-        ts_parts = ts_string.split(".")
-        if len(ts_parts) == 2:
-            # separate fractional seconds and UTC offset
-            offset_sep = "+" if "+" in ts_parts[1] else "-"
-            ts_frac_parts = ts_parts[1].split(offset_sep)
-            num_frac_parts = len(ts_frac_parts)
-
-            # pad to 6 digits if fractional seconds present
-            ts_frac_parts[0] = ts_frac_parts[0].ljust(6, "0")
-            ts_string = "".join(
-                [
-                    ts_parts[0],  # date and time
-                    ".",
-                    ts_frac_parts[0],  # fractional seconds
-                    offset_sep if num_frac_parts > 1 else "",
-                    ts_frac_parts[1] if num_frac_parts > 1 else "",  # utc offset (if present)
-                ]
-            )
-
-        # return literal with no timezone, eg turn '2020-01-01 12:13:14-08:00' into '2020-01-01 12:13:14'
-        # this is because Clickhouse encodes the timezone as a data type parameter and throws an error if
-        # it's part of the timestamp string
-        ts_without_tz = (
-            datetime.datetime.fromisoformat(ts_string).replace(tzinfo=None).isoformat(sep=" ")
-        )
-        ts = exp.Literal.string(ts_without_tz)
-
-    # Non-nullable DateTime64 with microsecond precision
-    expressions = [exp.DataTypeParam(this=tz)] if tz else []
-    datatype = exp.DataType.build(
-        exp.DataType.Type.DATETIME64,
-        expressions=[exp.DataTypeParam(this=exp.Literal.number(6)), *expressions],
-        nullable=False,
-    )
-
-    return self.sql(exp.cast(ts, datatype, dialect=self.dialect))
-
-
-def _map_sql(self: ClickHouse.Generator, expression: exp.Map | exp.VarMap) -> str:
-    if not (expression.parent and expression.parent.arg_key == "settings"):
-        return _lower_func(var_map_sql(self, expression))
-
-    keys = expression.args.get("keys")
-    values = expression.args.get("values")
-
-    if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
-        self.unsupported("Cannot convert array columns into map.")
-        return ""
-
-    args = []
-    for key, value in zip(keys.expressions, values.expressions):
-        args.append(f"{self.sql(key)}: {self.sql(value)}")
-
-    csv_args = ", ".join(args)
-
-    return f"{{{csv_args}}}"
-
-
-class ClickHouse(Dialect):
-    NORMALIZE_FUNCTIONS: bool | str = False
-    NULL_ORDERING = "nulls_are_last"
-    SUPPORTS_USER_DEFINED_TYPES = False
-    SAFE_DIVISION = True
-    LOG_BASE_FIRST: t.Optional[bool] = None
-    FORCE_EARLY_ALIAS_REF_EXPANSION = True
-    PRESERVE_ORIGINAL_NAMES = True
-    NUMBERS_CAN_BE_UNDERSCORE_SEPARATED = True
-    
IDENTIFIERS_CAN_START_WITH_DIGIT = True - HEX_STRING_IS_INTEGER_TYPE = True - - # https://github.com/ClickHouse/ClickHouse/issues/33935#issue-1112165779 - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_SENSITIVE - - UNESCAPED_SEQUENCES = { - "\\0": "\0", - } - - CREATABLE_KIND_MAPPING = {"DATABASE": "SCHEMA"} - - SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { - exp.Except: False, - exp.Intersect: False, - exp.Union: None, - } - - def generate_values_aliases(self, expression: exp.Values) -> t.List[exp.Identifier]: - # Clickhouse allows VALUES to have an embedded structure e.g: - # VALUES('person String, place String', ('Noah', 'Paris'), ...) - # In this case, we don't want to qualify the columns - values = expression.expressions[0].expressions - - structure = ( - values[0] - if (len(values) > 1 and values[0].is_string and isinstance(values[1], exp.Tuple)) - else None - ) - if structure: - # Split each column definition into the column name e.g: - # 'person String, place String' -> ['person', 'place'] - structure_coldefs = [coldef.strip() for coldef in structure.name.split(",")] - column_aliases = [ - exp.to_identifier(coldef.split(" ")[0]) for coldef in structure_coldefs - ] - else: - # Default column aliases in CH are "c1", "c2", etc. - column_aliases = [ - exp.to_identifier(f"c{i + 1}") for i in range(len(values[0].expressions)) - ] - - return column_aliases - - class Tokenizer(tokens.Tokenizer): - COMMENTS = ["--", "#", "#!", ("/*", "*/")] - IDENTIFIERS = ['"', "`"] - IDENTIFIER_ESCAPES = ["\\"] - STRING_ESCAPES = ["'", "\\"] - BIT_STRINGS = [("0b", "")] - HEX_STRINGS = [("0x", ""), ("0X", "")] - HEREDOC_STRINGS = ["$"] - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - ".:": TokenType.DOTCOLON, - "ATTACH": TokenType.COMMAND, - "DATE32": TokenType.DATE32, - "DATETIME64": TokenType.DATETIME64, - "DICTIONARY": TokenType.DICTIONARY, - "DYNAMIC": TokenType.DYNAMIC, - "ENUM8": TokenType.ENUM8, - "ENUM16": TokenType.ENUM16, - "EXCHANGE": TokenType.COMMAND, - "FINAL": TokenType.FINAL, - "FIXEDSTRING": TokenType.FIXEDSTRING, - "FLOAT32": TokenType.FLOAT, - "FLOAT64": TokenType.DOUBLE, - "GLOBAL": TokenType.GLOBAL, - "LOWCARDINALITY": TokenType.LOWCARDINALITY, - "MAP": TokenType.MAP, - "NESTED": TokenType.NESTED, - "NOTHING": TokenType.NOTHING, - "SAMPLE": TokenType.TABLE_SAMPLE, - "TUPLE": TokenType.STRUCT, - "UINT16": TokenType.USMALLINT, - "UINT32": TokenType.UINT, - "UINT64": TokenType.UBIGINT, - "UINT8": TokenType.UTINYINT, - "IPV4": TokenType.IPV4, - "IPV6": TokenType.IPV6, - "POINT": TokenType.POINT, - "RING": TokenType.RING, - "LINESTRING": TokenType.LINESTRING, - "MULTILINESTRING": TokenType.MULTILINESTRING, - "POLYGON": TokenType.POLYGON, - "MULTIPOLYGON": TokenType.MULTIPOLYGON, - "AGGREGATEFUNCTION": TokenType.AGGREGATEFUNCTION, - "SIMPLEAGGREGATEFUNCTION": TokenType.SIMPLEAGGREGATEFUNCTION, - "SYSTEM": TokenType.COMMAND, - "PREWHERE": TokenType.PREWHERE, - } - KEYWORDS.pop("/*+") - - SINGLE_TOKENS = { - **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.HEREDOC_STRING, - } - - class Parser(parser.Parser): - # Tested in ClickHouse's playground, it seems that the following two queries do the same thing - # * select x from t1 union all select x from t2 limit 1; - # * select x from t1 union all (select x from t2 limit 1); - MODIFIERS_ATTACHED_TO_SET_OP = False - INTERVAL_SPANS = False - OPTIONAL_ALIAS_TOKEN_CTE = False - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "ANY": exp.AnyValue.from_arg_list, - "ARRAYSUM": exp.ArraySum.from_arg_list, - 
"COUNTIF": _build_count_if, - "DATE_ADD": build_date_delta(exp.DateAdd, default_unit=None), - "DATEADD": build_date_delta(exp.DateAdd, default_unit=None), - "DATE_DIFF": build_date_delta(exp.DateDiff, default_unit=None, supports_timezone=True), - "DATEDIFF": build_date_delta(exp.DateDiff, default_unit=None, supports_timezone=True), - "DATE_FORMAT": _build_date_format, - "DATE_SUB": build_date_delta(exp.DateSub, default_unit=None), - "DATESUB": build_date_delta(exp.DateSub, default_unit=None), - "FORMATDATETIME": _build_date_format, - "JSONEXTRACTSTRING": build_json_extract_path( - exp.JSONExtractScalar, zero_based_indexing=False - ), - "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True), - "MAP": parser.build_var_map, - "MATCH": exp.RegexpLike.from_arg_list, - "RANDCANONICAL": exp.Rand.from_arg_list, - "STR_TO_DATE": _build_str_to_date, - "TUPLE": exp.Struct.from_arg_list, - "TIMESTAMP_SUB": build_date_delta(exp.TimestampSub, default_unit=None), - "TIMESTAMPSUB": build_date_delta(exp.TimestampSub, default_unit=None), - "TIMESTAMP_ADD": build_date_delta(exp.TimestampAdd, default_unit=None), - "TIMESTAMPADD": build_date_delta(exp.TimestampAdd, default_unit=None), - "UNIQ": exp.ApproxDistinct.from_arg_list, - "XOR": lambda args: exp.Xor(expressions=args), - "MD5": exp.MD5Digest.from_arg_list, - "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), - "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)), - "EDITDISTANCE": exp.Levenshtein.from_arg_list, - "LEVENSHTEINDISTANCE": exp.Levenshtein.from_arg_list, - } - FUNCTIONS.pop("TRANSFORM") - - AGG_FUNCTIONS = { - "count", - "min", - "max", - "sum", - "avg", - "any", - "stddevPop", - "stddevSamp", - "varPop", - "varSamp", - "corr", - "covarPop", - "covarSamp", - "entropy", - "exponentialMovingAverage", - "intervalLengthSum", - "kolmogorovSmirnovTest", - "mannWhitneyUTest", - "median", - "rankCorr", - "sumKahan", - "studentTTest", - "welchTTest", - "anyHeavy", - "anyLast", - "boundingRatio", - "first_value", - "last_value", - "argMin", - "argMax", - "avgWeighted", - "topK", - "topKWeighted", - "deltaSum", - "deltaSumTimestamp", - "groupArray", - "groupArrayLast", - "groupUniqArray", - "groupArrayInsertAt", - "groupArrayMovingAvg", - "groupArrayMovingSum", - "groupArraySample", - "groupBitAnd", - "groupBitOr", - "groupBitXor", - "groupBitmap", - "groupBitmapAnd", - "groupBitmapOr", - "groupBitmapXor", - "sumWithOverflow", - "sumMap", - "minMap", - "maxMap", - "skewSamp", - "skewPop", - "kurtSamp", - "kurtPop", - "uniq", - "uniqExact", - "uniqCombined", - "uniqCombined64", - "uniqHLL12", - "uniqTheta", - "quantile", - "quantiles", - "quantileExact", - "quantilesExact", - "quantileExactLow", - "quantilesExactLow", - "quantileExactHigh", - "quantilesExactHigh", - "quantileExactWeighted", - "quantilesExactWeighted", - "quantileTiming", - "quantilesTiming", - "quantileTimingWeighted", - "quantilesTimingWeighted", - "quantileDeterministic", - "quantilesDeterministic", - "quantileTDigest", - "quantilesTDigest", - "quantileTDigestWeighted", - "quantilesTDigestWeighted", - "quantileBFloat16", - "quantilesBFloat16", - "quantileBFloat16Weighted", - "quantilesBFloat16Weighted", - "simpleLinearRegression", - "stochasticLinearRegression", - "stochasticLogisticRegression", - "categoricalInformationValue", - "contingency", - "cramersV", - "cramersVBiasCorrected", - "theilsU", - "maxIntersections", - "maxIntersectionsPosition", - "meanZTest", - "quantileInterpolatedWeighted", - 
"quantilesInterpolatedWeighted", - "quantileGK", - "quantilesGK", - "sparkBar", - "sumCount", - "largestTriangleThreeBuckets", - "histogram", - "sequenceMatch", - "sequenceCount", - "windowFunnel", - "retention", - "uniqUpTo", - "sequenceNextNode", - "exponentialTimeDecayedAvg", - } - - AGG_FUNCTIONS_SUFFIXES = [ - "If", - "Array", - "ArrayIf", - "Map", - "SimpleState", - "State", - "Merge", - "MergeState", - "ForEach", - "Distinct", - "OrDefault", - "OrNull", - "Resample", - "ArgMin", - "ArgMax", - ] - - FUNC_TOKENS = { - *parser.Parser.FUNC_TOKENS, - TokenType.AND, - TokenType.OR, - TokenType.SET, - } - - RESERVED_TOKENS = parser.Parser.RESERVED_TOKENS - {TokenType.SELECT} - - ID_VAR_TOKENS = { - *parser.Parser.ID_VAR_TOKENS, - TokenType.LIKE, - } - - AGG_FUNC_MAPPING = ( - lambda functions, suffixes: { - f"{f}{sfx}": (f, sfx) for sfx in (suffixes + [""]) for f in functions - } - )(AGG_FUNCTIONS, AGG_FUNCTIONS_SUFFIXES) - - FUNCTIONS_WITH_ALIASED_ARGS = {*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "TUPLE"} - - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, - "ARRAYJOIN": lambda self: self.expression(exp.Explode, this=self._parse_expression()), - "QUANTILE": lambda self: self._parse_quantile(), - "MEDIAN": lambda self: self._parse_quantile(), - "COLUMNS": lambda self: self._parse_columns(), - } - - FUNCTION_PARSERS.pop("MATCH") - - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, - "ENGINE": lambda self: self._parse_engine_property(), - } - PROPERTY_PARSERS.pop("DYNAMIC") - - NO_PAREN_FUNCTION_PARSERS = parser.Parser.NO_PAREN_FUNCTION_PARSERS.copy() - NO_PAREN_FUNCTION_PARSERS.pop("ANY") - - NO_PAREN_FUNCTIONS = parser.Parser.NO_PAREN_FUNCTIONS.copy() - NO_PAREN_FUNCTIONS.pop(TokenType.CURRENT_TIMESTAMP) - - RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, - TokenType.GLOBAL: lambda self, this: self._parse_global_in(this), - } - - # The PLACEHOLDER entry is popped because 1) it doesn't affect Clickhouse (it corresponds to - # the postgres-specific JSONBContains parser) and 2) it makes parsing the ternary op simpler. 
- COLUMN_OPERATORS = parser.Parser.COLUMN_OPERATORS.copy() - COLUMN_OPERATORS.pop(TokenType.PLACEHOLDER) - - JOIN_KINDS = { - *parser.Parser.JOIN_KINDS, - TokenType.ANY, - TokenType.ASOF, - TokenType.ARRAY, - } - - TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { - TokenType.ANY, - TokenType.ARRAY, - TokenType.FINAL, - TokenType.FORMAT, - TokenType.SETTINGS, - } - - ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - { - TokenType.FORMAT, - } - - LOG_DEFAULTS_TO_LN = True - - QUERY_MODIFIER_PARSERS = { - **parser.Parser.QUERY_MODIFIER_PARSERS, - TokenType.SETTINGS: lambda self: ( - "settings", - self._advance() or self._parse_csv(self._parse_assignment), - ), - TokenType.FORMAT: lambda self: ("format", self._advance() or self._parse_id_var()), - } - - CONSTRAINT_PARSERS = { - **parser.Parser.CONSTRAINT_PARSERS, - "INDEX": lambda self: self._parse_index_constraint(), - "CODEC": lambda self: self._parse_compress(), - } - - ALTER_PARSERS = { - **parser.Parser.ALTER_PARSERS, - "REPLACE": lambda self: self._parse_alter_table_replace(), - } - - SCHEMA_UNNAMED_CONSTRAINTS = { - *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS, - "INDEX", - } - - PLACEHOLDER_PARSERS = { - **parser.Parser.PLACEHOLDER_PARSERS, - TokenType.L_BRACE: lambda self: self._parse_query_parameter(), - } - - def _parse_engine_property(self) -> exp.EngineProperty: - self._match(TokenType.EQ) - return self.expression( - exp.EngineProperty, - this=self._parse_field(any_token=True, anonymous_func=True), - ) - - # https://clickhouse.com/docs/en/sql-reference/statements/create/function - def _parse_user_defined_function_expression(self) -> t.Optional[exp.Expression]: - return self._parse_lambda() - - def _parse_types( - self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True - ) -> t.Optional[exp.Expression]: - dtype = super()._parse_types( - check_func=check_func, schema=schema, allow_identifiers=allow_identifiers - ) - if isinstance(dtype, exp.DataType) and dtype.args.get("nullable") is not True: - # Mark every type as non-nullable which is ClickHouse's default, unless it's - # already marked as nullable. This marker helps us transpile types from other - # dialects to ClickHouse, so that we can e.g. produce `CAST(x AS Nullable(String))` - # from `CAST(x AS TEXT)`. If there is a `NULL` value in `x`, the former would - # fail in ClickHouse without the `Nullable` type constructor. - dtype.set("nullable", False) - - return dtype - - def _parse_extract(self) -> exp.Extract | exp.Anonymous: - index = self._index - this = self._parse_bitwise() - if self._match(TokenType.FROM): - self._retreat(index) - return super()._parse_extract() - - # We return Anonymous here because extract and regexpExtract have different semantics, - # so parsing extract(foo, bar) into RegexpExtract can potentially break queries. E.g., - # `extract('foobar', 'b')` works, but ClickHouse crashes for `regexpExtract('foobar', 'b')`. - # - # TODO: can we somehow convert the former into an equivalent `regexpExtract` call? 
- self._match(TokenType.COMMA) - return self.expression( - exp.Anonymous, this="extract", expressions=[this, self._parse_bitwise()] - ) - - def _parse_assignment(self) -> t.Optional[exp.Expression]: - this = super()._parse_assignment() - - if self._match(TokenType.PLACEHOLDER): - return self.expression( - exp.If, - this=this, - true=self._parse_assignment(), - false=self._match(TokenType.COLON) and self._parse_assignment(), - ) - - return this - - def _parse_query_parameter(self) -> t.Optional[exp.Expression]: - """ - Parse a placeholder expression like SELECT {abc: UInt32} or FROM {table: Identifier} - https://clickhouse.com/docs/en/sql-reference/syntax#defining-and-using-query-parameters - """ - index = self._index - - this = self._parse_id_var() - self._match(TokenType.COLON) - kind = self._parse_types(check_func=False, allow_identifiers=False) or ( - self._match_text_seq("IDENTIFIER") and "Identifier" - ) - - if not kind: - self._retreat(index) - return None - elif not self._match(TokenType.R_BRACE): - self.raise_error("Expecting }") - - if isinstance(this, exp.Identifier) and not this.quoted: - this = exp.var(this.name) - - return self.expression(exp.Placeholder, this=this, kind=kind) - - def _parse_bracket( - self, this: t.Optional[exp.Expression] = None - ) -> t.Optional[exp.Expression]: - l_brace = self._match(TokenType.L_BRACE, advance=False) - bracket = super()._parse_bracket(this) - - if l_brace and isinstance(bracket, exp.Struct): - varmap = exp.VarMap(keys=exp.Array(), values=exp.Array()) - for expression in bracket.expressions: - if not isinstance(expression, exp.PropertyEQ): - break - - varmap.args["keys"].append("expressions", exp.Literal.string(expression.name)) - varmap.args["values"].append("expressions", expression.expression) - - return varmap - - return bracket - - def _parse_in(self, this: t.Optional[exp.Expression], is_global: bool = False) -> exp.In: - this = super()._parse_in(this) - this.set("is_global", is_global) - return this - - def _parse_global_in(self, this: t.Optional[exp.Expression]) -> exp.Not | exp.In: - is_negated = self._match(TokenType.NOT) - this = self._match(TokenType.IN) and self._parse_in(this, is_global=True) - return self.expression(exp.Not, this=this) if is_negated else this - - def _parse_table( - self, - schema: bool = False, - joins: bool = False, - alias_tokens: t.Optional[t.Collection[TokenType]] = None, - parse_bracket: bool = False, - is_db_reference: bool = False, - parse_partition: bool = False, - ) -> t.Optional[exp.Expression]: - this = super()._parse_table( - schema=schema, - joins=joins, - alias_tokens=alias_tokens, - parse_bracket=parse_bracket, - is_db_reference=is_db_reference, - ) - - if isinstance(this, exp.Table): - inner = this.this - alias = this.args.get("alias") - - if isinstance(inner, exp.GenerateSeries) and alias and not alias.columns: - alias.set("columns", [exp.to_identifier("generate_series")]) - - if self._match(TokenType.FINAL): - this = self.expression(exp.Final, this=this) - - return this - - def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: - return super()._parse_position(haystack_first=True) - - # https://clickhouse.com/docs/en/sql-reference/statements/select/with/ - def _parse_cte(self) -> t.Optional[exp.CTE]: - # WITH AS - cte: t.Optional[exp.CTE] = self._try_parse(super()._parse_cte) - - if not cte: - # WITH AS - cte = self.expression( - exp.CTE, - this=self._parse_assignment(), - alias=self._parse_table_alias(), - scalar=True, - ) - - return cte - - def _parse_join_parts( - 
self, - ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: - is_global = self._match(TokenType.GLOBAL) and self._prev - kind_pre = self._match_set(self.JOIN_KINDS, advance=False) and self._prev - - if kind_pre: - kind = self._match_set(self.JOIN_KINDS) and self._prev - side = self._match_set(self.JOIN_SIDES) and self._prev - return is_global, side, kind - - return ( - is_global, - self._match_set(self.JOIN_SIDES) and self._prev, - self._match_set(self.JOIN_KINDS) and self._prev, - ) - - def _parse_join( - self, skip_join_token: bool = False, parse_bracket: bool = False - ) -> t.Optional[exp.Join]: - join = super()._parse_join(skip_join_token=skip_join_token, parse_bracket=True) - if join: - join.set("global", join.args.pop("method", None)) - - # tbl ARRAY JOIN arr <-- this should be a `Column` reference, not a `Table` - # https://clickhouse.com/docs/en/sql-reference/statements/select/array-join - if join.kind == "ARRAY": - for table in join.find_all(exp.Table): - table.replace(table.to_column()) - - return join - - def _parse_function( - self, - functions: t.Optional[t.Dict[str, t.Callable]] = None, - anonymous: bool = False, - optional_parens: bool = True, - any_token: bool = False, - ) -> t.Optional[exp.Expression]: - expr = super()._parse_function( - functions=functions, - anonymous=anonymous, - optional_parens=optional_parens, - any_token=any_token, - ) - - func = expr.this if isinstance(expr, exp.Window) else expr - - # Aggregate functions can be split in 2 parts: - parts = ( - self.AGG_FUNC_MAPPING.get(func.this) if isinstance(func, exp.Anonymous) else None - ) - - if parts: - anon_func: exp.Anonymous = t.cast(exp.Anonymous, func) - params = self._parse_func_params(anon_func) - - kwargs = { - "this": anon_func.this, - "expressions": anon_func.expressions, - } - if parts[1]: - exp_class: t.Type[exp.Expression] = ( - exp.CombinedParameterizedAgg if params else exp.CombinedAggFunc - ) - else: - exp_class = exp.ParameterizedAgg if params else exp.AnonymousAggFunc - - kwargs["exp_class"] = exp_class - if params: - kwargs["params"] = params - - func = self.expression(**kwargs) - - if isinstance(expr, exp.Window): - # The window's func was parsed as Anonymous in base parser, fix its - # type to be ClickHouse style CombinedAnonymousAggFunc / AnonymousAggFunc - expr.set("this", func) - elif params: - # Params have blocked super()._parse_function() from parsing the following window - # (if that exists) as they're standing between the function call and the window spec - expr = self._parse_window(func) - else: - expr = func - - return expr - - def _parse_func_params( - self, this: t.Optional[exp.Func] = None - ) -> t.Optional[t.List[exp.Expression]]: - if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): - return self._parse_csv(self._parse_lambda) - - if self._match(TokenType.L_PAREN): - params = self._parse_csv(self._parse_lambda) - self._match_r_paren(this) - return params - - return None - - def _parse_quantile(self) -> exp.Quantile: - this = self._parse_lambda() - params = self._parse_func_params() - if params: - return self.expression(exp.Quantile, this=params[0], quantile=this) - return self.expression(exp.Quantile, this=this, quantile=exp.Literal.number(0.5)) - - def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]: - return super()._parse_wrapped_id_vars(optional=True) - - def _parse_primary_key( - self, wrapped_optional: bool = False, in_props: bool = False - ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey: - return 
super()._parse_primary_key( - wrapped_optional=wrapped_optional or in_props, in_props=in_props - ) - - def _parse_on_property(self) -> t.Optional[exp.Expression]: - index = self._index - if self._match_text_seq("CLUSTER"): - this = self._parse_string() or self._parse_id_var() - if this: - return self.expression(exp.OnCluster, this=this) - else: - self._retreat(index) - return None - - def _parse_index_constraint( - self, kind: t.Optional[str] = None - ) -> exp.IndexColumnConstraint: - # INDEX name1 expr TYPE type1(args) GRANULARITY value - this = self._parse_id_var() - expression = self._parse_assignment() - - index_type = self._match_text_seq("TYPE") and ( - self._parse_function() or self._parse_var() - ) - - granularity = self._match_text_seq("GRANULARITY") and self._parse_term() - - return self.expression( - exp.IndexColumnConstraint, - this=this, - expression=expression, - index_type=index_type, - granularity=granularity, - ) - - def _parse_partition(self) -> t.Optional[exp.Partition]: - # https://clickhouse.com/docs/en/sql-reference/statements/alter/partition#how-to-set-partition-expression - if not self._match(TokenType.PARTITION): - return None - - if self._match_text_seq("ID"): - # Corresponds to the PARTITION ID syntax - expressions: t.List[exp.Expression] = [ - self.expression(exp.PartitionId, this=self._parse_string()) - ] - else: - expressions = self._parse_expressions() - - return self.expression(exp.Partition, expressions=expressions) - - def _parse_alter_table_replace(self) -> t.Optional[exp.Expression]: - partition = self._parse_partition() - - if not partition or not self._match(TokenType.FROM): - return None - - return self.expression( - exp.ReplacePartition, expression=partition, source=self._parse_table_parts() - ) - - def _parse_projection_def(self) -> t.Optional[exp.ProjectionDef]: - if not self._match_text_seq("PROJECTION"): - return None - - return self.expression( - exp.ProjectionDef, - this=self._parse_id_var(), - expression=self._parse_wrapped(self._parse_statement), - ) - - def _parse_constraint(self) -> t.Optional[exp.Expression]: - return super()._parse_constraint() or self._parse_projection_def() - - def _parse_alias( - self, this: t.Optional[exp.Expression], explicit: bool = False - ) -> t.Optional[exp.Expression]: - # In clickhouse "SELECT APPLY(...)" is a query modifier, - # so "APPLY" shouldn't be parsed as 's alias. However, "SELECT apply" is a valid alias - if self._match_pair(TokenType.APPLY, TokenType.L_PAREN, advance=False): - return this - - return super()._parse_alias(this=this, explicit=explicit) - - def _parse_expression(self) -> t.Optional[exp.Expression]: - this = super()._parse_expression() - - # Clickhouse allows "SELECT [APPLY(func)] [...]]" modifier - while self._match_pair(TokenType.APPLY, TokenType.L_PAREN): - this = exp.Apply(this=this, expression=self._parse_var(any_token=True)) - self._match(TokenType.R_PAREN) - - return this - - def _parse_columns(self) -> exp.Expression: - this: exp.Expression = self.expression(exp.Columns, this=self._parse_lambda()) - - while self._next and self._match_text_seq(")", "APPLY", "("): - self._match(TokenType.R_PAREN) - this = exp.Apply(this=this, expression=self._parse_var(any_token=True)) - return this - - def _parse_value(self, values: bool = True) -> t.Optional[exp.Tuple]: - value = super()._parse_value(values=values) - if not value: - return None - - # In Clickhouse "SELECT * FROM VALUES (1, 2, 3)" generates a table with a single column, in contrast - # to other dialects. 
For this case, we canonicalize the values into a tuple-of-tuples AST if it's not already one. - # In INSERT INTO statements the same clause actually references multiple columns (opposite semantics), - # but the final result is not altered by the extra parentheses. - # Note: Clickhouse allows VALUES([structure], value, ...) so the branch checks for the last expression - expressions = value.expressions - if values and not isinstance(expressions[-1], exp.Tuple): - value.set( - "expressions", - [self.expression(exp.Tuple, expressions=[expr]) for expr in expressions], - ) - - return value - - class Generator(generator.Generator): - QUERY_HINTS = False - STRUCT_DELIMITER = ("(", ")") - NVL2_SUPPORTED = False - TABLESAMPLE_REQUIRES_PARENS = False - TABLESAMPLE_SIZE_IS_ROWS = False - TABLESAMPLE_KEYWORDS = "SAMPLE" - LAST_DAY_SUPPORTS_DATE_PART = False - CAN_IMPLEMENT_ARRAY_ANY = True - SUPPORTS_TO_NUMBER = False - JOIN_HINTS = False - TABLE_HINTS = False - GROUPINGS_SEP = "" - SET_OP_MODIFIERS = False - ARRAY_SIZE_NAME = "LENGTH" - WRAP_DERIVED_VALUES = False - - STRING_TYPE_MAPPING = { - exp.DataType.Type.BLOB: "String", - exp.DataType.Type.CHAR: "String", - exp.DataType.Type.LONGBLOB: "String", - exp.DataType.Type.LONGTEXT: "String", - exp.DataType.Type.MEDIUMBLOB: "String", - exp.DataType.Type.MEDIUMTEXT: "String", - exp.DataType.Type.TINYBLOB: "String", - exp.DataType.Type.TINYTEXT: "String", - exp.DataType.Type.TEXT: "String", - exp.DataType.Type.VARBINARY: "String", - exp.DataType.Type.VARCHAR: "String", - } - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - **STRING_TYPE_MAPPING, - exp.DataType.Type.ARRAY: "Array", - exp.DataType.Type.BOOLEAN: "Bool", - exp.DataType.Type.BIGINT: "Int64", - exp.DataType.Type.DATE32: "Date32", - exp.DataType.Type.DATETIME: "DateTime", - exp.DataType.Type.DATETIME2: "DateTime", - exp.DataType.Type.SMALLDATETIME: "DateTime", - exp.DataType.Type.DATETIME64: "DateTime64", - exp.DataType.Type.DECIMAL: "Decimal", - exp.DataType.Type.DECIMAL32: "Decimal32", - exp.DataType.Type.DECIMAL64: "Decimal64", - exp.DataType.Type.DECIMAL128: "Decimal128", - exp.DataType.Type.DECIMAL256: "Decimal256", - exp.DataType.Type.TIMESTAMP: "DateTime", - exp.DataType.Type.TIMESTAMPNTZ: "DateTime", - exp.DataType.Type.TIMESTAMPTZ: "DateTime", - exp.DataType.Type.DOUBLE: "Float64", - exp.DataType.Type.ENUM: "Enum", - exp.DataType.Type.ENUM8: "Enum8", - exp.DataType.Type.ENUM16: "Enum16", - exp.DataType.Type.FIXEDSTRING: "FixedString", - exp.DataType.Type.FLOAT: "Float32", - exp.DataType.Type.INT: "Int32", - exp.DataType.Type.MEDIUMINT: "Int32", - exp.DataType.Type.INT128: "Int128", - exp.DataType.Type.INT256: "Int256", - exp.DataType.Type.LOWCARDINALITY: "LowCardinality", - exp.DataType.Type.MAP: "Map", - exp.DataType.Type.NESTED: "Nested", - exp.DataType.Type.NOTHING: "Nothing", - exp.DataType.Type.SMALLINT: "Int16", - exp.DataType.Type.STRUCT: "Tuple", - exp.DataType.Type.TINYINT: "Int8", - exp.DataType.Type.UBIGINT: "UInt64", - exp.DataType.Type.UINT: "UInt32", - exp.DataType.Type.UINT128: "UInt128", - exp.DataType.Type.UINT256: "UInt256", - exp.DataType.Type.USMALLINT: "UInt16", - exp.DataType.Type.UTINYINT: "UInt8", - exp.DataType.Type.IPV4: "IPv4", - exp.DataType.Type.IPV6: "IPv6", - exp.DataType.Type.POINT: "Point", - exp.DataType.Type.RING: "Ring", - exp.DataType.Type.LINESTRING: "LineString", - exp.DataType.Type.MULTILINESTRING: "MultiLineString", - 
exp.DataType.Type.POLYGON: "Polygon", - exp.DataType.Type.MULTIPOLYGON: "MultiPolygon", - exp.DataType.Type.AGGREGATEFUNCTION: "AggregateFunction", - exp.DataType.Type.SIMPLEAGGREGATEFUNCTION: "SimpleAggregateFunction", - exp.DataType.Type.DYNAMIC: "Dynamic", - } - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.AnyValue: rename_func("any"), - exp.ApproxDistinct: rename_func("uniq"), - exp.ArrayConcat: rename_func("arrayConcat"), - exp.ArrayFilter: lambda self, e: self.func("arrayFilter", e.expression, e.this), - exp.ArrayRemove: remove_from_array_using_filter, - exp.ArraySum: rename_func("arraySum"), - exp.ArgMax: arg_max_or_min_no_count("argMax"), - exp.ArgMin: arg_max_or_min_no_count("argMin"), - exp.Array: inline_array_sql, - exp.CastToStrType: rename_func("CAST"), - exp.CountIf: rename_func("countIf"), - exp.CompressColumnConstraint: lambda self, - e: f"CODEC({self.expressions(e, key='this', flat=True)})", - exp.ComputedColumnConstraint: lambda self, - e: f"{'MATERIALIZED' if e.args.get('persisted') else 'ALIAS'} {self.sql(e, 'this')}", - exp.CurrentDate: lambda self, e: self.func("CURRENT_DATE"), - exp.DateAdd: _datetime_delta_sql("DATE_ADD"), - exp.DateDiff: _datetime_delta_sql("DATE_DIFF"), - exp.DateStrToDate: rename_func("toDate"), - exp.DateSub: _datetime_delta_sql("DATE_SUB"), - exp.Explode: rename_func("arrayJoin"), - exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", - exp.IsNan: rename_func("isNaN"), - exp.JSONCast: lambda self, e: f"{self.sql(e, 'this')}.:{self.sql(e, 'to')}", - exp.JSONExtract: json_extract_segments("JSONExtractString", quoted_index=False), - exp.JSONExtractScalar: json_extract_segments("JSONExtractString", quoted_index=False), - exp.JSONPathKey: json_path_key_only_name, - exp.JSONPathRoot: lambda *_: "", - exp.Length: length_or_char_length_sql, - exp.Map: _map_sql, - exp.Median: rename_func("median"), - exp.Nullif: rename_func("nullIf"), - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", - exp.Pivot: no_pivot_sql, - exp.Quantile: _quantile_sql, - exp.RegexpLike: lambda self, e: self.func("match", e.this, e.expression), - exp.Rand: rename_func("randCanonical"), - exp.StartsWith: rename_func("startsWith"), - exp.EndsWith: rename_func("endsWith"), - exp.StrPosition: lambda self, e: strposition_sql( - self, - e, - func_name="POSITION", - supports_position=True, - use_ansi_position=False, - ), - exp.TimeToStr: lambda self, e: self.func( - "formatDateTime", e.this, self.format_time(e), e.args.get("zone") - ), - exp.TimeStrToTime: _timestrtotime_sql, - exp.TimestampAdd: _datetime_delta_sql("TIMESTAMP_ADD"), - exp.TimestampSub: _datetime_delta_sql("TIMESTAMP_SUB"), - exp.VarMap: _map_sql, - exp.Xor: lambda self, e: self.func("xor", e.this, e.expression, *e.expressions), - exp.MD5Digest: rename_func("MD5"), - exp.MD5: lambda self, e: self.func("LOWER", self.func("HEX", self.func("MD5", e.this))), - exp.SHA: rename_func("SHA1"), - exp.SHA2: sha256_sql, - exp.UnixToTime: _unix_to_time_sql, - exp.TimestampTrunc: timestamptrunc_sql(zone=True), - exp.Trim: lambda self, e: trim_sql(self, e, default_trim_type="BOTH"), - exp.Variance: rename_func("varSamp"), - exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), - exp.Stddev: rename_func("stddevSamp"), - exp.Chr: rename_func("CHAR"), - exp.Lag: lambda self, e: self.func( - "lagInFrame", e.this, e.args.get("offset"), e.args.get("default") - ), - exp.Lead: lambda self, e: self.func( - "leadInFrame", e.this, e.args.get("offset"), e.args.get("default") - ), - 
exp.Levenshtein: unsupported_args("ins_cost", "del_cost", "sub_cost", "max_dist")( - rename_func("editDistance") - ), - } - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.OnCluster: exp.Properties.Location.POST_NAME, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.ToTableProperty: exp.Properties.Location.POST_NAME, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - # There's no list in docs, but it can be found in Clickhouse code - # see `ClickHouse/src/Parsers/ParserCreate*.cpp` - ON_CLUSTER_TARGETS = { - "SCHEMA", # Transpiled CREATE SCHEMA may have OnCluster property set - "DATABASE", - "TABLE", - "VIEW", - "DICTIONARY", - "INDEX", - "FUNCTION", - "NAMED COLLECTION", - } - - # https://clickhouse.com/docs/en/sql-reference/data-types/nullable - NON_NULLABLE_TYPES = { - exp.DataType.Type.ARRAY, - exp.DataType.Type.MAP, - exp.DataType.Type.STRUCT, - exp.DataType.Type.POINT, - exp.DataType.Type.RING, - exp.DataType.Type.LINESTRING, - exp.DataType.Type.MULTILINESTRING, - exp.DataType.Type.POLYGON, - exp.DataType.Type.MULTIPOLYGON, - } - - def strtodate_sql(self, expression: exp.StrToDate) -> str: - strtodate_sql = self.function_fallback_sql(expression) - - if not isinstance(expression.parent, exp.Cast): - # StrToDate returns DATEs in other dialects (eg. postgres), so - # this branch aims to improve the transpilation to clickhouse - return self.cast_sql(exp.cast(expression, "DATE")) - - return strtodate_sql - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - this = expression.this - - if isinstance(this, exp.StrToDate) and expression.to == exp.DataType.build("datetime"): - return self.sql(this) - - return super().cast_sql(expression, safe_prefix=safe_prefix) - - def trycast_sql(self, expression: exp.TryCast) -> str: - dtype = expression.to - if not dtype.is_type(*self.NON_NULLABLE_TYPES, check_nullable=True): - # Casting x into Nullable(T) appears to behave similarly to TRY_CAST(x AS T) - dtype.set("nullable", True) - - return super().cast_sql(expression) - - def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str: - this = self.json_path_part(expression.this) - return str(int(this) + 1) if is_int(this) else this - - def likeproperty_sql(self, expression: exp.LikeProperty) -> str: - return f"AS {self.sql(expression, 'this')}" - - def _any_to_has( - self, - expression: exp.EQ | exp.NEQ, - default: t.Callable[[t.Any], str], - prefix: str = "", - ) -> str: - if isinstance(expression.left, exp.Any): - arr = expression.left - this = expression.right - elif isinstance(expression.right, exp.Any): - arr = expression.right - this = expression.left - else: - return default(expression) - - return prefix + self.func("has", arr.this.unnest(), this) - - def eq_sql(self, expression: exp.EQ) -> str: - return self._any_to_has(expression, super().eq_sql) - - def neq_sql(self, expression: exp.NEQ) -> str: - return self._any_to_has(expression, super().neq_sql, "NOT ") - - def regexpilike_sql(self, expression: exp.RegexpILike) -> str: - # Manually add a flag to make the search case-insensitive - regex = self.func("CONCAT", "'(?i)'", expression.expression) - return self.func("match", expression.this, regex) - - def datatype_sql(self, expression: exp.DataType) -> str: - # String is the standard ClickHouse type, every other variant is just an alias. - # Additionally, any supplied length parameter will be ignored. 
- # - # https://clickhouse.com/docs/en/sql-reference/data-types/string - if expression.this in self.STRING_TYPE_MAPPING: - dtype = "String" - else: - dtype = super().datatype_sql(expression) - - # This section changes the type to `Nullable(...)` if the following conditions hold: - # - It's marked as nullable - this ensures we won't wrap ClickHouse types with `Nullable` - # and change their semantics - # - It's not the key type of a `Map`. This is because ClickHouse enforces the following - # constraint: "Type of Map key must be a type, that can be represented by integer or - # String or FixedString (possibly LowCardinality) or UUID or IPv6" - # - It's not a composite type, e.g. `Nullable(Array(...))` is not a valid type - parent = expression.parent - nullable = expression.args.get("nullable") - if nullable is True or ( - nullable is None - and not ( - isinstance(parent, exp.DataType) - and parent.is_type(exp.DataType.Type.MAP, check_nullable=True) - and expression.index in (None, 0) - ) - and not expression.is_type(*self.NON_NULLABLE_TYPES, check_nullable=True) - ): - dtype = f"Nullable({dtype})" - - return dtype - - def cte_sql(self, expression: exp.CTE) -> str: - if expression.args.get("scalar"): - this = self.sql(expression, "this") - alias = self.sql(expression, "alias") - return f"{this} AS {alias}" - - return super().cte_sql(expression) - - def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]: - return super().after_limit_modifiers(expression) + [ - ( - self.seg("SETTINGS ") + self.expressions(expression, key="settings", flat=True) - if expression.args.get("settings") - else "" - ), - ( - self.seg("FORMAT ") + self.sql(expression, "format") - if expression.args.get("format") - else "" - ), - ] - - def placeholder_sql(self, expression: exp.Placeholder) -> str: - return f"{{{expression.name}: {self.sql(expression, 'kind')}}}" - - def oncluster_sql(self, expression: exp.OnCluster) -> str: - return f"ON CLUSTER {self.sql(expression, 'this')}" - - def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: - if expression.kind in self.ON_CLUSTER_TARGETS and locations.get( - exp.Properties.Location.POST_NAME - ): - this_name = self.sql( - expression.this if isinstance(expression.this, exp.Schema) else expression, - "this", - ) - this_properties = " ".join( - [self.sql(prop) for prop in locations[exp.Properties.Location.POST_NAME]] - ) - this_schema = self.schema_columns_sql(expression.this) - this_schema = f"{self.sep()}{this_schema}" if this_schema else "" - - return f"{this_name}{self.sep()}{this_properties}{this_schema}" - - return super().createable_sql(expression, locations) - - def create_sql(self, expression: exp.Create) -> str: - # The comment property comes last in CTAS statements, i.e. 
after the query - query = expression.expression - if isinstance(query, exp.Query): - comment_prop = expression.find(exp.SchemaCommentProperty) - if comment_prop: - comment_prop.pop() - query.replace(exp.paren(query)) - else: - comment_prop = None - - create_sql = super().create_sql(expression) - - comment_sql = self.sql(comment_prop) - comment_sql = f" {comment_sql}" if comment_sql else "" - - return f"{create_sql}{comment_sql}" - - def prewhere_sql(self, expression: exp.PreWhere) -> str: - this = self.indent(self.sql(expression, "this")) - return f"{self.seg('PREWHERE')}{self.sep()}{this}" - - def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - expr = self.sql(expression, "expression") - expr = f" {expr}" if expr else "" - index_type = self.sql(expression, "index_type") - index_type = f" TYPE {index_type}" if index_type else "" - granularity = self.sql(expression, "granularity") - granularity = f" GRANULARITY {granularity}" if granularity else "" - - return f"INDEX{this}{expr}{index_type}{granularity}" - - def partition_sql(self, expression: exp.Partition) -> str: - return f"PARTITION {self.expressions(expression, flat=True)}" - - def partitionid_sql(self, expression: exp.PartitionId) -> str: - return f"ID {self.sql(expression.this)}" - - def replacepartition_sql(self, expression: exp.ReplacePartition) -> str: - return ( - f"REPLACE {self.sql(expression.expression)} FROM {self.sql(expression, 'source')}" - ) - - def projectiondef_sql(self, expression: exp.ProjectionDef) -> str: - return f"PROJECTION {self.sql(expression.this)} {self.wrap(expression.expression)}" - - def is_sql(self, expression: exp.Is) -> str: - is_sql = super().is_sql(expression) - - if isinstance(expression.parent, exp.Not): - # value IS NOT NULL -> NOT (value IS NULL) - is_sql = self.wrap(is_sql) - - return is_sql - - def in_sql(self, expression: exp.In) -> str: - in_sql = super().in_sql(expression) - - if isinstance(expression.parent, exp.Not) and expression.args.get("is_global"): - in_sql = in_sql.replace("GLOBAL IN", "GLOBAL NOT IN", 1) - - return in_sql - - def not_sql(self, expression: exp.Not) -> str: - if isinstance(expression.this, exp.In) and expression.this.args.get("is_global"): - # let `GLOBAL IN` child interpose `NOT` - return self.sql(expression, "this") - - return super().not_sql(expression) - - def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str: - # If the VALUES clause contains tuples of expressions, we need to treat it - # as a table since Clickhouse will automatically alias it as such. 
- alias = expression.args.get("alias") - - if alias and alias.args.get("columns") and expression.expressions: - values = expression.expressions[0].expressions - values_as_table = any(isinstance(value, exp.Tuple) for value in values) - else: - values_as_table = True - - return super().values_sql(expression, values_as_table=values_as_table) diff --git a/altimate_packages/sqlglot/dialects/databricks.py b/altimate_packages/sqlglot/dialects/databricks.py deleted file mode 100644 index f13b4ca44..000000000 --- a/altimate_packages/sqlglot/dialects/databricks.py +++ /dev/null @@ -1,131 +0,0 @@ -from __future__ import annotations - -from copy import deepcopy -from collections import defaultdict - -from sqlglot import exp, transforms, jsonpath -from sqlglot.dialects.dialect import ( - date_delta_sql, - build_date_delta, - timestamptrunc_sql, - build_formatted_time, -) -from sqlglot.dialects.spark import Spark -from sqlglot.tokens import TokenType -from sqlglot.optimizer.annotate_types import TypeAnnotator - - -def _jsonextract_sql( - self: Databricks.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar -) -> str: - this = self.sql(expression, "this") - expr = self.sql(expression, "expression") - return f"{this}:{expr}" - - -class Databricks(Spark): - SAFE_DIVISION = False - COPY_PARAMS_ARE_CSV = False - - COERCES_TO = defaultdict(set, deepcopy(TypeAnnotator.COERCES_TO)) - for text_type in exp.DataType.TEXT_TYPES: - COERCES_TO[text_type] |= { - *exp.DataType.NUMERIC_TYPES, - *exp.DataType.TEMPORAL_TYPES, - exp.DataType.Type.BINARY, - exp.DataType.Type.BOOLEAN, - exp.DataType.Type.INTERVAL, - } - - class JSONPathTokenizer(jsonpath.JSONPathTokenizer): - IDENTIFIERS = ["`", '"'] - - class Tokenizer(Spark.Tokenizer): - KEYWORDS = { - **Spark.Tokenizer.KEYWORDS, - "VOID": TokenType.VOID, - } - - class Parser(Spark.Parser): - LOG_DEFAULTS_TO_LN = True - STRICT_CAST = True - COLON_IS_VARIANT_EXTRACT = True - - FUNCTIONS = { - **Spark.Parser.FUNCTIONS, - "DATEADD": build_date_delta(exp.DateAdd), - "DATE_ADD": build_date_delta(exp.DateAdd), - "DATEDIFF": build_date_delta(exp.DateDiff), - "DATE_DIFF": build_date_delta(exp.DateDiff), - "TO_DATE": build_formatted_time(exp.TsOrDsToDate, "databricks"), - } - - FACTOR = { - **Spark.Parser.FACTOR, - TokenType.COLON: exp.JSONExtract, - } - - class Generator(Spark.Generator): - TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" - COPY_PARAMS_ARE_WRAPPED = False - COPY_PARAMS_EQ_REQUIRED = True - JSON_PATH_SINGLE_QUOTE_ESCAPE = False - QUOTE_JSON_PATH = False - PARSE_JSON_NAME = "PARSE_JSON" - - TRANSFORMS = { - **Spark.Generator.TRANSFORMS, - exp.DateAdd: date_delta_sql("DATEADD"), - exp.DateDiff: date_delta_sql("DATEDIFF"), - exp.DatetimeAdd: lambda self, e: self.func( - "TIMESTAMPADD", e.unit, e.expression, e.this - ), - exp.DatetimeSub: lambda self, e: self.func( - "TIMESTAMPADD", - e.unit, - exp.Mul(this=e.expression, expression=exp.Literal.number(-1)), - e.this, - ), - exp.DatetimeTrunc: timestamptrunc_sql(), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_distinct_on, - transforms.unnest_to_explode, - transforms.any_to_exists, - ] - ), - exp.JSONExtract: _jsonextract_sql, - exp.JSONExtractScalar: _jsonextract_sql, - exp.JSONPathRoot: lambda *_: "", - exp.ToChar: lambda self, e: self.function_fallback_sql(e), - } - - TRANSFORMS.pop(exp.TryCast) - - TYPE_MAPPING = { - **Spark.Generator.TYPE_MAPPING, - exp.DataType.Type.NULL: "VOID", - } - - def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: - constraint = 
expression.find(exp.GeneratedAsIdentityColumnConstraint) - kind = expression.kind - if ( - constraint - and isinstance(kind, exp.DataType) - and kind.this in exp.DataType.INTEGER_TYPES - ): - # only BIGINT generated identity constraints are supported - expression.set("kind", exp.DataType.build("bigint")) - - return super().columndef_sql(expression, sep) - - def generatedasidentitycolumnconstraint_sql( - self, expression: exp.GeneratedAsIdentityColumnConstraint - ) -> str: - expression.set("this", True) # trigger ALWAYS in super class - return super().generatedasidentitycolumnconstraint_sql(expression) - - def jsonpath_sql(self, expression: exp.JSONPath) -> str: - expression.set("escape", None) - return super().jsonpath_sql(expression) diff --git a/altimate_packages/sqlglot/dialects/dialect.py b/altimate_packages/sqlglot/dialects/dialect.py deleted file mode 100644 index 233b585fd..000000000 --- a/altimate_packages/sqlglot/dialects/dialect.py +++ /dev/null @@ -1,1915 +0,0 @@ -from __future__ import annotations - -import importlib -import logging -import typing as t -import sys - -from enum import Enum, auto -from functools import reduce - -from sqlglot import exp -from sqlglot.dialects import DIALECT_MODULE_NAMES -from sqlglot.errors import ParseError -from sqlglot.generator import Generator, unsupported_args -from sqlglot.helper import ( - AutoName, - flatten, - is_int, - seq_get, - subclasses, - suggest_closest_match_and_fail, - to_bool, -) -from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path -from sqlglot.parser import Parser -from sqlglot.time import TIMEZONES, format_time, subsecond_precision -from sqlglot.tokens import Token, Tokenizer, TokenType -from sqlglot.trie import new_trie - -DATE_ADD_OR_DIFF = t.Union[ - exp.DateAdd, - exp.DateDiff, - exp.DateSub, - exp.TsOrDsAdd, - exp.TsOrDsDiff, -] -DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub] -JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar] - - -if t.TYPE_CHECKING: - from sqlglot._typing import B, E, F - - from sqlglot.optimizer.annotate_types import TypeAnnotator - - AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]] - -logger = logging.getLogger("sqlglot") - -UNESCAPED_SEQUENCES = { - "\\a": "\a", - "\\b": "\b", - "\\f": "\f", - "\\n": "\n", - "\\r": "\r", - "\\t": "\t", - "\\v": "\v", - "\\\\": "\\", -} - - -def annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: - return lambda self, e: self._annotate_with_type(e, data_type) - - -class Dialects(str, Enum): - """Dialects supported by SQLGLot.""" - - DIALECT = "" - - ATHENA = "athena" - BIGQUERY = "bigquery" - CLICKHOUSE = "clickhouse" - DATABRICKS = "databricks" - DORIS = "doris" - DRILL = "drill" - DRUID = "druid" - DUCKDB = "duckdb" - DUNE = "dune" - HIVE = "hive" - MATERIALIZE = "materialize" - MYSQL = "mysql" - ORACLE = "oracle" - POSTGRES = "postgres" - PRESTO = "presto" - PRQL = "prql" - REDSHIFT = "redshift" - RISINGWAVE = "risingwave" - SNOWFLAKE = "snowflake" - SPARK = "spark" - SPARK2 = "spark2" - SQLITE = "sqlite" - STARROCKS = "starrocks" - TABLEAU = "tableau" - TERADATA = "teradata" - TRINO = "trino" - TSQL = "tsql" - - -class NormalizationStrategy(str, AutoName): - """Specifies the strategy according to which identifiers should be normalized.""" - - LOWERCASE = auto() - """Unquoted identifiers are lowercased.""" - - UPPERCASE = auto() - """Unquoted identifiers are uppercased.""" - - CASE_SENSITIVE = auto() - """Always case-sensitive, regardless of 
quotes.""" - - CASE_INSENSITIVE = auto() - """Always case-insensitive, regardless of quotes.""" - - -class Version(int): - def __new__(cls, version_str: t.Optional[str], *args, **kwargs): - if version_str: - parts = version_str.split(".") - parts.extend(["0"] * (3 - len(parts))) - v = int("".join([p.zfill(3) for p in parts])) - else: - # No version defined means we should support the latest engine semantics, so - # the comparison to any specific version should yield that latest is greater - v = sys.maxsize - - return super(Version, cls).__new__(cls, v) - - -class _Dialect(type): - _classes: t.Dict[str, t.Type[Dialect]] = {} - - def __eq__(cls, other: t.Any) -> bool: - if cls is other: - return True - if isinstance(other, str): - return cls is cls.get(other) - if isinstance(other, Dialect): - return cls is type(other) - - return False - - def __hash__(cls) -> int: - return hash(cls.__name__.lower()) - - @property - def classes(cls): - if len(DIALECT_MODULE_NAMES) != len(cls._classes): - for key in DIALECT_MODULE_NAMES: - cls._try_load(key) - - return cls._classes - - @classmethod - def _try_load(cls, key: str | Dialects) -> None: - if isinstance(key, Dialects): - key = key.value - - # This import will lead to a new dialect being loaded, and hence, registered. - # We check that the key is an actual sqlglot module to avoid blindly importing - # files. Custom user dialects need to be imported at the top-level package, in - # order for them to be registered as soon as possible. - if key in DIALECT_MODULE_NAMES: - importlib.import_module(f"sqlglot.dialects.{key}") - - @classmethod - def __getitem__(cls, key: str) -> t.Type[Dialect]: - if key not in cls._classes: - cls._try_load(key) - - return cls._classes[key] - - @classmethod - def get( - cls, key: str, default: t.Optional[t.Type[Dialect]] = None - ) -> t.Optional[t.Type[Dialect]]: - if key not in cls._classes: - cls._try_load(key) - - return cls._classes.get(key, default) - - def __new__(cls, clsname, bases, attrs): - klass = super().__new__(cls, clsname, bases, attrs) - enum = Dialects.__members__.get(clsname.upper()) - cls._classes[enum.value if enum is not None else clsname.lower()] = klass - - klass.TIME_TRIE = new_trie(klass.TIME_MAPPING) - klass.FORMAT_TRIE = ( - new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE - ) - klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()} - klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING) - klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()} - klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING) - - klass.INVERSE_CREATABLE_KIND_MAPPING = { - v: k for k, v in klass.CREATABLE_KIND_MAPPING.items() - } - - base = seq_get(bases, 0) - base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),) - base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),) - base_parser = (getattr(base, "parser_class", Parser),) - base_generator = (getattr(base, "generator_class", Generator),) - - klass.tokenizer_class = klass.__dict__.get( - "Tokenizer", type("Tokenizer", base_tokenizer, {}) - ) - klass.jsonpath_tokenizer_class = klass.__dict__.get( - "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {}) - ) - klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {})) - klass.generator_class = klass.__dict__.get( - "Generator", type("Generator", base_generator, {}) - ) - - klass.QUOTE_START, klass.QUOTE_END = 
list(klass.tokenizer_class._QUOTES.items())[0] - klass.IDENTIFIER_START, klass.IDENTIFIER_END = list( - klass.tokenizer_class._IDENTIFIERS.items() - )[0] - - def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]: - return next( - ( - (s, e) - for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items() - if t == token_type - ), - (None, None), - ) - - klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING) - klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING) - klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING) - klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING) - - if "\\" in klass.tokenizer_class.STRING_ESCAPES: - klass.UNESCAPED_SEQUENCES = { - **UNESCAPED_SEQUENCES, - **klass.UNESCAPED_SEQUENCES, - } - - klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()} - - klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS - - if enum not in ("", "bigquery"): - klass.generator_class.SELECT_KINDS = () - - if enum not in ("", "athena", "presto", "trino", "duckdb"): - klass.generator_class.TRY_SUPPORTED = False - klass.generator_class.SUPPORTS_UESCAPE = False - - if enum not in ("", "databricks", "hive", "spark", "spark2"): - modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy() - for modifier in ("cluster", "distribute", "sort"): - modifier_transforms.pop(modifier, None) - - klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms - - if enum not in ("", "doris", "mysql"): - klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | { - TokenType.STRAIGHT_JOIN, - } - klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { - TokenType.STRAIGHT_JOIN, - } - - if not klass.SUPPORTS_SEMI_ANTI_JOIN: - klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | { - TokenType.ANTI, - TokenType.SEMI, - } - - return klass - - -class Dialect(metaclass=_Dialect): - INDEX_OFFSET = 0 - """The base index offset for arrays.""" - - WEEK_OFFSET = 0 - """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" - - UNNEST_COLUMN_ONLY = False - """Whether `UNNEST` table aliases are treated as column aliases.""" - - ALIAS_POST_TABLESAMPLE = False - """Whether the table alias comes after tablesample.""" - - TABLESAMPLE_SIZE_IS_PERCENT = False - """Whether a size in the table sample clause represents percentage.""" - - NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE - """Specifies the strategy according to which identifiers should be normalized.""" - - IDENTIFIERS_CAN_START_WITH_DIGIT = False - """Whether an unquoted identifier can start with a digit.""" - - DPIPE_IS_STRING_CONCAT = True - """Whether the DPIPE token (`||`) is a string concatenation operator.""" - - STRICT_STRING_CONCAT = False - """Whether `CONCAT`'s arguments must be strings.""" - - SUPPORTS_USER_DEFINED_TYPES = True - """Whether user-defined data types are supported.""" - - SUPPORTS_SEMI_ANTI_JOIN = True - """Whether `SEMI` or `ANTI` joins are supported.""" - - SUPPORTS_COLUMN_JOIN_MARKS = False - """Whether the old-style outer join (+) syntax is supported.""" - - COPY_PARAMS_ARE_CSV = True - """Separator of COPY statement parameters.""" - - NORMALIZE_FUNCTIONS: bool | str = "upper" - """ - Determines how function names are going to be normalized. - Possible values: - "upper" or True: Convert names to uppercase. 
- "lower": Convert names to lowercase. - False: Disables function name normalization. - """ - - PRESERVE_ORIGINAL_NAMES: bool = False - """ - Whether the name of the function should be preserved inside the node's metadata, - can be useful for roundtripping deprecated vs new functions that share an AST node - e.g JSON_VALUE vs JSON_EXTRACT_SCALAR in BigQuery - """ - - LOG_BASE_FIRST: t.Optional[bool] = True - """ - Whether the base comes first in the `LOG` function. - Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) - """ - - NULL_ORDERING = "nulls_are_small" - """ - Default `NULL` ordering method to use if not explicitly set. - Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` - """ - - TYPED_DIVISION = False - """ - Whether the behavior of `a / b` depends on the types of `a` and `b`. - False means `a / b` is always float division. - True means `a / b` is integer division if both `a` and `b` are integers. - """ - - SAFE_DIVISION = False - """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" - - CONCAT_COALESCE = False - """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" - - HEX_LOWERCASE = False - """Whether the `HEX` function returns a lowercase hexadecimal string.""" - - DATE_FORMAT = "'%Y-%m-%d'" - DATEINT_FORMAT = "'%Y%m%d'" - TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" - - TIME_MAPPING: t.Dict[str, str] = {} - """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" - - # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time - # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE - FORMAT_MAPPING: t.Dict[str, str] = {} - """ - Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. - If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. - """ - - UNESCAPED_SEQUENCES: t.Dict[str, str] = {} - """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" - - PSEUDOCOLUMNS: t.Set[str] = set() - """ - Columns that are auto-generated by the engine corresponding to this dialect. - For example, such columns may be excluded from `SELECT *` queries. - """ - - PREFER_CTE_ALIAS_COLUMN = False - """ - Some dialects, such as Snowflake, allow you to reference a CTE column alias in the - HAVING clause of the CTE. This flag will cause the CTE alias columns to override - any projection aliases in the subquery. - - For example, - WITH y(c) AS ( - SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 - ) SELECT c FROM y; - - will be rewritten as - - WITH y(c) AS ( - SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 - ) SELECT c FROM y; - """ - - COPY_PARAMS_ARE_CSV = True - """ - Whether COPY statement parameters are separated by comma or whitespace - """ - - FORCE_EARLY_ALIAS_REF_EXPANSION = False - """ - Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 
- - For example: - WITH data AS ( - SELECT - 1 AS id, - 2 AS my_id - ) - SELECT - id AS my_id - FROM - data - WHERE - my_id = 1 - GROUP BY - my_id, - HAVING - my_id = 1 - - In most dialects, "my_id" would refer to "data.my_id" across the query, except: - - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e - it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" - - Clickhouse, which will forward the alias across the query i.e it resolves - to "WHERE id = 1 GROUP BY id HAVING id = 1" - """ - - EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False - """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" - - SUPPORTS_ORDER_BY_ALL = False - """ - Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks - """ - - HAS_DISTINCT_ARRAY_CONSTRUCTORS = False - """ - Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) - as the former is of type INT[] vs the latter which is SUPER - """ - - SUPPORTS_FIXED_SIZE_ARRAYS = False - """ - Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. - in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should - be interpreted as a subscript/index operator. - """ - - STRICT_JSON_PATH_SYNTAX = True - """Whether failing to parse a JSON path expression using the JSONPath dialect will log a warning.""" - - ON_CONDITION_EMPTY_BEFORE_ERROR = True - """Whether "X ON EMPTY" should come before "X ON ERROR" (for dialects like T-SQL, MySQL, Oracle).""" - - ARRAY_AGG_INCLUDES_NULLS: t.Optional[bool] = True - """Whether ArrayAgg needs to filter NULL values.""" - - PROMOTE_TO_INFERRED_DATETIME_TYPE = False - """ - This flag is used in the optimizer's canonicalize rule and determines whether x will be promoted - to the literal's type in x::DATE < '2020-01-01 12:05:03' (i.e., DATETIME). When false, the literal - is cast to x's type to match it instead. - """ - - SUPPORTS_VALUES_DEFAULT = True - """Whether the DEFAULT keyword is supported in the VALUES clause.""" - - NUMBERS_CAN_BE_UNDERSCORE_SEPARATED = False - """Whether number literals can include underscores for better readability""" - - HEX_STRING_IS_INTEGER_TYPE: bool = False - """Whether hex strings such as x'CC' evaluate to integer or binary/blob type""" - - REGEXP_EXTRACT_DEFAULT_GROUP = 0 - """The default value for the capturing group.""" - - SET_OP_DISTINCT_BY_DEFAULT: t.Dict[t.Type[exp.Expression], t.Optional[bool]] = { - exp.Except: True, - exp.Intersect: True, - exp.Union: True, - } - """ - Whether a set operation uses DISTINCT by default. This is `None` when either `DISTINCT` or `ALL` - must be explicitly specified. - """ - - CREATABLE_KIND_MAPPING: dict[str, str] = {} - """ - Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse - equivalent of CREATE SCHEMA is CREATE DATABASE. 
- """ - - # Whether ADD is present for each column added by ALTER TABLE - ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = True - - # --- Autofilled --- - - tokenizer_class = Tokenizer - jsonpath_tokenizer_class = JSONPathTokenizer - parser_class = Parser - generator_class = Generator - - # A trie of the time_mapping keys - TIME_TRIE: t.Dict = {} - FORMAT_TRIE: t.Dict = {} - - INVERSE_TIME_MAPPING: t.Dict[str, str] = {} - INVERSE_TIME_TRIE: t.Dict = {} - INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} - INVERSE_FORMAT_TRIE: t.Dict = {} - - INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} - - ESCAPED_SEQUENCES: t.Dict[str, str] = {} - - # Delimiters for string literals and identifiers - QUOTE_START = "'" - QUOTE_END = "'" - IDENTIFIER_START = '"' - IDENTIFIER_END = '"' - - # Delimiters for bit, hex, byte and unicode literals - BIT_START: t.Optional[str] = None - BIT_END: t.Optional[str] = None - HEX_START: t.Optional[str] = None - HEX_END: t.Optional[str] = None - BYTE_START: t.Optional[str] = None - BYTE_END: t.Optional[str] = None - UNICODE_START: t.Optional[str] = None - UNICODE_END: t.Optional[str] = None - - DATE_PART_MAPPING = { - "Y": "YEAR", - "YY": "YEAR", - "YYY": "YEAR", - "YYYY": "YEAR", - "YR": "YEAR", - "YEARS": "YEAR", - "YRS": "YEAR", - "MM": "MONTH", - "MON": "MONTH", - "MONS": "MONTH", - "MONTHS": "MONTH", - "D": "DAY", - "DD": "DAY", - "DAYS": "DAY", - "DAYOFMONTH": "DAY", - "DAY OF WEEK": "DAYOFWEEK", - "WEEKDAY": "DAYOFWEEK", - "DOW": "DAYOFWEEK", - "DW": "DAYOFWEEK", - "WEEKDAY_ISO": "DAYOFWEEKISO", - "DOW_ISO": "DAYOFWEEKISO", - "DW_ISO": "DAYOFWEEKISO", - "DAY OF YEAR": "DAYOFYEAR", - "DOY": "DAYOFYEAR", - "DY": "DAYOFYEAR", - "W": "WEEK", - "WK": "WEEK", - "WEEKOFYEAR": "WEEK", - "WOY": "WEEK", - "WY": "WEEK", - "WEEK_ISO": "WEEKISO", - "WEEKOFYEARISO": "WEEKISO", - "WEEKOFYEAR_ISO": "WEEKISO", - "Q": "QUARTER", - "QTR": "QUARTER", - "QTRS": "QUARTER", - "QUARTERS": "QUARTER", - "H": "HOUR", - "HH": "HOUR", - "HR": "HOUR", - "HOURS": "HOUR", - "HRS": "HOUR", - "M": "MINUTE", - "MI": "MINUTE", - "MIN": "MINUTE", - "MINUTES": "MINUTE", - "MINS": "MINUTE", - "S": "SECOND", - "SEC": "SECOND", - "SECONDS": "SECOND", - "SECS": "SECOND", - "MS": "MILLISECOND", - "MSEC": "MILLISECOND", - "MSECS": "MILLISECOND", - "MSECOND": "MILLISECOND", - "MSECONDS": "MILLISECOND", - "MILLISEC": "MILLISECOND", - "MILLISECS": "MILLISECOND", - "MILLISECON": "MILLISECOND", - "MILLISECONDS": "MILLISECOND", - "US": "MICROSECOND", - "USEC": "MICROSECOND", - "USECS": "MICROSECOND", - "MICROSEC": "MICROSECOND", - "MICROSECS": "MICROSECOND", - "USECOND": "MICROSECOND", - "USECONDS": "MICROSECOND", - "MICROSECONDS": "MICROSECOND", - "NS": "NANOSECOND", - "NSEC": "NANOSECOND", - "NANOSEC": "NANOSECOND", - "NSECOND": "NANOSECOND", - "NSECONDS": "NANOSECOND", - "NANOSECS": "NANOSECOND", - "EPOCH_SECOND": "EPOCH", - "EPOCH_SECONDS": "EPOCH", - "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", - "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", - "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", - "TZH": "TIMEZONE_HOUR", - "TZM": "TIMEZONE_MINUTE", - "DEC": "DECADE", - "DECS": "DECADE", - "DECADES": "DECADE", - "MIL": "MILLENIUM", - "MILS": "MILLENIUM", - "MILLENIA": "MILLENIUM", - "C": "CENTURY", - "CENT": "CENTURY", - "CENTS": "CENTURY", - "CENTURIES": "CENTURY", - } - - TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { - exp.DataType.Type.BIGINT: { - exp.ApproxDistinct, - exp.ArraySize, - exp.Length, - }, - exp.DataType.Type.BOOLEAN: { - exp.Between, - exp.Boolean, - exp.In, - exp.RegexpLike, - 
}, - exp.DataType.Type.DATE: { - exp.CurrentDate, - exp.Date, - exp.DateFromParts, - exp.DateStrToDate, - exp.DiToDate, - exp.StrToDate, - exp.TimeStrToDate, - exp.TsOrDsToDate, - }, - exp.DataType.Type.DATETIME: { - exp.CurrentDatetime, - exp.Datetime, - exp.DatetimeAdd, - exp.DatetimeSub, - }, - exp.DataType.Type.DOUBLE: { - exp.ApproxQuantile, - exp.Avg, - exp.Exp, - exp.Ln, - exp.Log, - exp.Pow, - exp.Quantile, - exp.Round, - exp.SafeDivide, - exp.Sqrt, - exp.Stddev, - exp.StddevPop, - exp.StddevSamp, - exp.ToDouble, - exp.Variance, - exp.VariancePop, - }, - exp.DataType.Type.INT: { - exp.Ceil, - exp.DatetimeDiff, - exp.DateDiff, - exp.TimestampDiff, - exp.TimeDiff, - exp.DateToDi, - exp.Levenshtein, - exp.Sign, - exp.StrPosition, - exp.TsOrDiToDi, - }, - exp.DataType.Type.JSON: { - exp.ParseJSON, - }, - exp.DataType.Type.TIME: { - exp.CurrentTime, - exp.Time, - exp.TimeAdd, - exp.TimeSub, - }, - exp.DataType.Type.TIMESTAMP: { - exp.CurrentTimestamp, - exp.StrToTime, - exp.TimeStrToTime, - exp.TimestampAdd, - exp.TimestampSub, - exp.UnixToTime, - }, - exp.DataType.Type.TINYINT: { - exp.Day, - exp.Month, - exp.Week, - exp.Year, - exp.Quarter, - }, - exp.DataType.Type.VARCHAR: { - exp.ArrayConcat, - exp.Concat, - exp.ConcatWs, - exp.DateToDateStr, - exp.DPipe, - exp.GroupConcat, - exp.Initcap, - exp.Lower, - exp.Substring, - exp.String, - exp.TimeToStr, - exp.TimeToTimeStr, - exp.Trim, - exp.TsOrDsToDateStr, - exp.UnixToStr, - exp.UnixToTimeStr, - exp.Upper, - }, - } - - ANNOTATORS: AnnotatorsType = { - **{ - expr_type: lambda self, e: self._annotate_unary(e) - for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) - }, - **{ - expr_type: lambda self, e: self._annotate_binary(e) - for expr_type in subclasses(exp.__name__, exp.Binary) - }, - **{ - expr_type: annotate_with_type_lambda(data_type) - for data_type, expressions in TYPE_TO_EXPRESSIONS.items() - for expr_type in expressions - }, - exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), - exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), - exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), - exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), - exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), - exp.Bracket: lambda self, e: self._annotate_bracket(e), - exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), - exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), - exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), - exp.Count: lambda self, e: self._annotate_with_type( - e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT - ), - exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), - exp.DateAdd: lambda self, e: self._annotate_timeunit(e), - exp.DateSub: lambda self, e: self._annotate_timeunit(e), - exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), - exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), - exp.Div: lambda self, e: self._annotate_div(e), - exp.Dot: lambda self, e: self._annotate_dot(e), - exp.Explode: lambda self, e: self._annotate_explode(e), - exp.Extract: lambda self, e: self._annotate_extract(e), - exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), - exp.GenerateDateArray: lambda self, e: self._annotate_with_type( - e, exp.DataType.build("ARRAY") - ), - exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( - e, 
exp.DataType.build("ARRAY") - ), - exp.Greatest: lambda self, e: self._annotate_by_args(e, "this", "expressions"), - exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), - exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), - exp.Least: lambda self, e: self._annotate_by_args(e, "this", "expressions"), - exp.Literal: lambda self, e: self._annotate_literal(e), - exp.Map: lambda self, e: self._annotate_map(e), - exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), - exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), - exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), - exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), - exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), - exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), - exp.Struct: lambda self, e: self._annotate_struct(e), - exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), - exp.SortArray: lambda self, e: self._annotate_by_args(e, "this"), - exp.Timestamp: lambda self, e: self._annotate_with_type( - e, - exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, - ), - exp.ToMap: lambda self, e: self._annotate_to_map(e), - exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), - exp.Unnest: lambda self, e: self._annotate_unnest(e), - exp.VarMap: lambda self, e: self._annotate_map(e), - } - - # Specifies what types a given type can be coerced into - COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} - - # Determines the supported Dialect instance settings - SUPPORTED_SETTINGS = { - "normalization_strategy", - "version", - } - - @classmethod - def get_or_raise(cls, dialect: DialectType) -> Dialect: - """ - Look up a dialect in the global dialect registry and return it if it exists. - - Args: - dialect: The target dialect. If this is a string, it can be optionally followed by - additional key-value pairs that are separated by commas and are used to specify - dialect settings, such as whether the dialect's identifiers are case-sensitive. - - Example: - >>> dialect = dialect_class = get_or_raise("duckdb") - >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") - - Returns: - The corresponding Dialect instance. - """ - - if not dialect: - return cls() - if isinstance(dialect, _Dialect): - return dialect() - if isinstance(dialect, Dialect): - return dialect - if isinstance(dialect, str): - try: - dialect_name, *kv_strings = dialect.split(",") - kv_pairs = (kv.split("=") for kv in kv_strings) - kwargs = {} - for pair in kv_pairs: - key = pair[0].strip() - value: t.Union[bool | str | None] = None - - if len(pair) == 1: - # Default initialize standalone settings to True - value = True - elif len(pair) == 2: - value = pair[1].strip() - - kwargs[key] = to_bool(value) - - except ValueError: - raise ValueError( - f"Invalid dialect format: '{dialect}'. " - "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
- ) - - result = cls.get(dialect_name.strip()) - if not result: - suggest_closest_match_and_fail("dialect", dialect_name, list(DIALECT_MODULE_NAMES)) - - assert result is not None - return result(**kwargs) - - raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") - - @classmethod - def format_time( - cls, expression: t.Optional[str | exp.Expression] - ) -> t.Optional[exp.Expression]: - """Converts a time format in this dialect to its equivalent Python `strftime` format.""" - if isinstance(expression, str): - return exp.Literal.string( - # the time formats are quoted - format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) - ) - - if expression and expression.is_string: - return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) - - return expression - - def __init__(self, **kwargs) -> None: - self.version = Version(kwargs.pop("version", None)) - - normalization_strategy = kwargs.pop("normalization_strategy", None) - if normalization_strategy is None: - self.normalization_strategy = self.NORMALIZATION_STRATEGY - else: - self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) - - self.settings = kwargs - - for unsupported_setting in kwargs.keys() - self.SUPPORTED_SETTINGS: - suggest_closest_match_and_fail("setting", unsupported_setting, self.SUPPORTED_SETTINGS) - - def __eq__(self, other: t.Any) -> bool: - # Does not currently take dialect state into account - return type(self) == other - - def __hash__(self) -> int: - # Does not currently take dialect state into account - return hash(type(self)) - - def normalize_identifier(self, expression: E) -> E: - """ - Transforms an identifier in a way that resembles how it'd be resolved by this dialect. - - For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it - lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so - it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, - and so any normalization would be prohibited in order to avoid "breaking" the identifier. - - There are also dialects like Spark, which are case-insensitive even when quotes are - present, and dialects like MySQL, whose resolution rules match those employed by the - underlying operating system, for example they may always be case-sensitive in Linux. - - Finally, the normalization behavior of some engines can even be controlled through flags, - like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. - - SQLGlot aims to understand and handle all of these different behaviors gracefully, so - that it can analyze queries in the optimizer and successfully capture their semantics. 
- """ - if ( - isinstance(expression, exp.Identifier) - and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE - and ( - not expression.quoted - or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE - ) - ): - expression.set( - "this", - ( - expression.this.upper() - if self.normalization_strategy is NormalizationStrategy.UPPERCASE - else expression.this.lower() - ), - ) - - return expression - - def case_sensitive(self, text: str) -> bool: - """Checks if text contains any case sensitive characters, based on the dialect's rules.""" - if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: - return False - - unsafe = ( - str.islower - if self.normalization_strategy is NormalizationStrategy.UPPERCASE - else str.isupper - ) - return any(unsafe(char) for char in text) - - def can_identify(self, text: str, identify: str | bool = "safe") -> bool: - """Checks if text can be identified given an identify option. - - Args: - text: The text to check. - identify: - `"always"` or `True`: Always returns `True`. - `"safe"`: Only returns `True` if the identifier is case-insensitive. - - Returns: - Whether the given text can be identified. - """ - if identify is True or identify == "always": - return True - - if identify == "safe": - return not self.case_sensitive(text) - - return False - - def quote_identifier(self, expression: E, identify: bool = True) -> E: - """ - Adds quotes to a given identifier. - - Args: - expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. - identify: If set to `False`, the quotes will only be added if the identifier is deemed - "unsafe", with respect to its characters and this dialect's normalization strategy. - """ - if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): - name = expression.this - expression.set( - "quoted", - identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), - ) - - return expression - - def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if isinstance(path, exp.Literal): - path_text = path.name - if path.is_number: - path_text = f"[{path_text}]" - try: - return parse_json_path(path_text, self) - except ParseError as e: - if self.STRICT_JSON_PATH_SYNTAX: - logger.warning(f"Invalid JSON path syntax. 
{str(e)}") - - return path - - def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: - return self.parser(**opts).parse(self.tokenize(sql), sql) - - def parse_into( - self, expression_type: exp.IntoType, sql: str, **opts - ) -> t.List[t.Optional[exp.Expression]]: - return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) - - def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: - return self.generator(**opts).generate(expression, copy=copy) - - def transpile(self, sql: str, **opts) -> t.List[str]: - return [ - self.generate(expression, copy=False, **opts) if expression else "" - for expression in self.parse(sql) - ] - - def tokenize(self, sql: str) -> t.List[Token]: - return self.tokenizer.tokenize(sql) - - @property - def tokenizer(self) -> Tokenizer: - return self.tokenizer_class(dialect=self) - - @property - def jsonpath_tokenizer(self) -> JSONPathTokenizer: - return self.jsonpath_tokenizer_class(dialect=self) - - def parser(self, **opts) -> Parser: - return self.parser_class(dialect=self, **opts) - - def generator(self, **opts) -> Generator: - return self.generator_class(dialect=self, **opts) - - def generate_values_aliases(self, expression: exp.Values) -> t.List[exp.Identifier]: - return [ - exp.to_identifier(f"_col_{i}") - for i, _ in enumerate(expression.expressions[0].expressions) - ] - - -DialectType = t.Union[str, Dialect, t.Type[Dialect], None] - - -def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: - return lambda self, expression: self.func(name, *flatten(expression.args.values())) - - -@unsupported_args("accuracy") -def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: - return self.func("APPROX_COUNT_DISTINCT", expression.this) - - -def if_sql( - name: str = "IF", false_value: t.Optional[exp.Expression | str] = None -) -> t.Callable[[Generator, exp.If], str]: - def _if_sql(self: Generator, expression: exp.If) -> str: - return self.func( - name, - expression.this, - expression.args.get("true"), - expression.args.get("false") or false_value, - ) - - return _if_sql - - -def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: - this = expression.this - if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: - this.replace(exp.cast(this, exp.DataType.Type.JSON)) - - return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") - - -def inline_array_sql(self: Generator, expression: exp.Array) -> str: - return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" - - -def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: - elem = seq_get(expression.expressions, 0) - if isinstance(elem, exp.Expression) and elem.find(exp.Query): - return self.func("ARRAY", elem) - return inline_array_sql(self, expression) - - -def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: - return self.like_sql( - exp.Like( - this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) - ) - ) - - -def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: - zone = self.sql(expression, "this") - return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" - - -def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: - if expression.args.get("recursive"): - self.unsupported("Recursive CTEs are unsupported") - expression.args["recursive"] 
= False - return self.with_sql(expression) - - -def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: - self.unsupported("TABLESAMPLE unsupported") - return self.sql(expression.this) - - -def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: - self.unsupported("PIVOT unsupported") - return "" - - -def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: - return self.cast_sql(expression) - - -def no_comment_column_constraint_sql( - self: Generator, expression: exp.CommentColumnConstraint -) -> str: - self.unsupported("CommentColumnConstraint unsupported") - return "" - - -def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: - self.unsupported("MAP_FROM_ENTRIES unsupported") - return "" - - -def property_sql(self: Generator, expression: exp.Property) -> str: - return f"{self.property_name(expression, string_key=True)}={self.sql(expression, 'value')}" - - -def strposition_sql( - self: Generator, - expression: exp.StrPosition, - func_name: str = "STRPOS", - supports_position: bool = False, - supports_occurrence: bool = False, - use_ansi_position: bool = True, -) -> str: - string = expression.this - substr = expression.args.get("substr") - position = expression.args.get("position") - occurrence = expression.args.get("occurrence") - zero = exp.Literal.number(0) - one = exp.Literal.number(1) - - if supports_occurrence and occurrence and supports_position and not position: - position = one - - transpile_position = position and not supports_position - if transpile_position: - string = exp.Substring(this=string, start=position) - - if func_name == "POSITION" and use_ansi_position: - func = exp.Anonymous(this=func_name, expressions=[exp.In(this=substr, field=string)]) - else: - args = [substr, string] if func_name in ("LOCATE", "CHARINDEX") else [string, substr] - if supports_position: - args.append(position) - if occurrence: - if supports_occurrence: - args.append(occurrence) - else: - self.unsupported(f"{func_name} does not support the occurrence parameter.") - func = exp.Anonymous(this=func_name, expressions=args) - - if transpile_position: - func_with_offset = exp.Sub(this=func + position, expression=one) - func_wrapped = exp.If(this=func.eq(zero), true=zero, false=func_with_offset) - return self.sql(func_wrapped) - - return self.sql(func) - - -def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: - return ( - f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" - ) - - -def var_map_sql( - self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" -) -> str: - keys = expression.args.get("keys") - values = expression.args.get("values") - - if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): - self.unsupported("Cannot convert array columns into map.") - return self.func(map_func_name, keys, values) - - args = [] - for key, value in zip(keys.expressions, values.expressions): - args.append(self.sql(key)) - args.append(self.sql(value)) - - return self.func(map_func_name, *args) - - -def build_formatted_time( - exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None -) -> t.Callable[[t.List], E]: - """Helper used for time expressions. - - Args: - exp_class: the expression class to instantiate. - dialect: target sql dialect. - default: the default format, True being time. - - Returns: - A callable that can be used to return the appropriately formatted time expression. 
- """ - - def _builder(args: t.List): - return exp_class( - this=seq_get(args, 0), - format=Dialect[dialect].format_time( - seq_get(args, 1) - or (Dialect[dialect].TIME_FORMAT if default is True else default or None) - ), - ) - - return _builder - - -def time_format( - dialect: DialectType = None, -) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: - def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: - """ - Returns the time format for a given expression, unless it's equivalent - to the default time format of the dialect of interest. - """ - time_format = self.format_time(expression) - return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None - - return _time_format - - -def build_date_delta( - exp_class: t.Type[E], - unit_mapping: t.Optional[t.Dict[str, str]] = None, - default_unit: t.Optional[str] = "DAY", - supports_timezone: bool = False, -) -> t.Callable[[t.List], E]: - def _builder(args: t.List) -> E: - unit_based = len(args) >= 3 - has_timezone = len(args) == 4 - this = args[2] if unit_based else seq_get(args, 0) - unit = None - if unit_based or default_unit: - unit = args[0] if unit_based else exp.Literal.string(default_unit) - unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit - expression = exp_class(this=this, expression=seq_get(args, 1), unit=unit) - if supports_timezone and has_timezone: - expression.set("zone", args[-1]) - return expression - - return _builder - - -def build_date_delta_with_interval( - expression_class: t.Type[E], -) -> t.Callable[[t.List], t.Optional[E]]: - def _builder(args: t.List) -> t.Optional[E]: - if len(args) < 2: - return None - - interval = args[1] - - if not isinstance(interval, exp.Interval): - raise ParseError(f"INTERVAL expression expected but got '{interval}'") - - return expression_class(this=args[0], expression=interval.this, unit=unit_to_str(interval)) - - return _builder - - -def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: - unit = seq_get(args, 0) - this = seq_get(args, 1) - - if isinstance(this, exp.Cast) and this.is_type("date"): - return exp.DateTrunc(unit=unit, this=this) - return exp.TimestampTrunc(this=this, unit=unit) - - -def date_add_interval_sql( - data_type: str, kind: str -) -> t.Callable[[Generator, exp.Expression], str]: - def func(self: Generator, expression: exp.Expression) -> str: - this = self.sql(expression, "this") - interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) - return f"{data_type}_{kind}({this}, {self.sql(interval)})" - - return func - - -def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: - def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: - args = [unit_to_str(expression), expression.this] - if zone: - args.append(expression.args.get("zone")) - return self.func("DATE_TRUNC", *args) - - return _timestamptrunc_sql - - -def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: - zone = expression.args.get("zone") - if not zone: - from sqlglot.optimizer.annotate_types import annotate_types - - target_type = ( - annotate_types(expression, dialect=self.dialect).type or exp.DataType.Type.TIMESTAMP - ) - return self.sql(exp.cast(expression.this, target_type)) - if zone.name.lower() in TIMEZONES: - return self.sql( - exp.AtTimeZone( - this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), - zone=zone, - ) - ) - return 
self.func("TIMESTAMP", expression.this, zone) - - -def no_time_sql(self: Generator, expression: exp.Time) -> str: - # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ AT TIME ZONE AS TIME) - this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) - expr = exp.cast( - exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME - ) - return self.sql(expr) - - -def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: - this = expression.this - expr = expression.expression - - if expr.name.lower() in TIMEZONES: - # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ AT TIME ZONE AS TIMESTAMP) - this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) - this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) - return self.sql(this) - - this = exp.cast(this, exp.DataType.Type.DATE) - expr = exp.cast(expr, exp.DataType.Type.TIME) - - return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP)) - - -def left_to_substring_sql(self: Generator, expression: exp.Left) -> str: - return self.sql( - exp.Substring( - this=expression.this, start=exp.Literal.number(1), length=expression.expression - ) - ) - - -def right_to_substring_sql(self: Generator, expression: exp.Left) -> str: - return self.sql( - exp.Substring( - this=expression.this, - start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1), - ) - ) - - -def timestrtotime_sql( - self: Generator, - expression: exp.TimeStrToTime, - include_precision: bool = False, -) -> str: - datatype = exp.DataType.build( - exp.DataType.Type.TIMESTAMPTZ - if expression.args.get("zone") - else exp.DataType.Type.TIMESTAMP - ) - - if isinstance(expression.this, exp.Literal) and include_precision: - precision = subsecond_precision(expression.this.name) - if precision > 0: - datatype = exp.DataType.build( - datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] - ) - - return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect)) - - -def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: - return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE)) - - -# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 -def encode_decode_sql( - self: Generator, expression: exp.Expression, name: str, replace: bool = True -) -> str: - charset = expression.args.get("charset") - if charset and charset.name.lower() != "utf-8": - self.unsupported(f"Expected utf-8 character set, got {charset}.") - - return self.func(name, expression.this, expression.args.get("replace") if replace else None) - - -def min_or_least(self: Generator, expression: exp.Min) -> str: - name = "LEAST" if expression.expressions else "MIN" - return rename_func(name)(self, expression) - - -def max_or_greatest(self: Generator, expression: exp.Max) -> str: - name = "GREATEST" if expression.expressions else "MAX" - return rename_func(name)(self, expression) - - -def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: - cond = expression.this - - if isinstance(expression.this, exp.Distinct): - cond = expression.this.expressions[0] - self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") - - return self.func("sum", exp.func("if", cond, 1, 0)) - - -def trim_sql(self: Generator, expression: exp.Trim, default_trim_type: str = "") -> str: - target = self.sql(expression, "this") - trim_type = self.sql(expression, "position") or default_trim_type - 
remove_chars = self.sql(expression, "expression") - collation = self.sql(expression, "collation") - - # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific - if not remove_chars: - return self.trim_sql(expression) - - trim_type = f"{trim_type} " if trim_type else "" - remove_chars = f"{remove_chars} " if remove_chars else "" - from_part = "FROM " if trim_type or remove_chars else "" - collation = f" COLLATE {collation}" if collation else "" - return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" - - -def str_to_time_sql(self: Generator, expression: exp.Expression) -> str: - return self.func("STRPTIME", expression.this, self.format_time(expression)) - - -def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str: - return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions)) - - -def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str: - delim, *rest_args = expression.expressions - return self.sql( - reduce( - lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)), - rest_args, - ) - ) - - -@unsupported_args("position", "occurrence", "parameters") -def regexp_extract_sql( - self: Generator, expression: exp.RegexpExtract | exp.RegexpExtractAll -) -> str: - group = expression.args.get("group") - - # Do not render group if it's the default value for this dialect - if group and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP): - group = None - - return self.func(expression.sql_name(), expression.this, expression.expression, group) - - -@unsupported_args("position", "occurrence", "modifiers") -def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: - return self.func( - "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] - ) - - -def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: - names = [] - for agg in aggregations: - if isinstance(agg, exp.Alias): - names.append(agg.alias) - else: - """ - This case corresponds to aggregations without aliases being used as suffixes - (e.g. col_avg(foo)). We need to unquote identifiers because they're going to - be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. - Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 
- """ - agg_all_unquoted = agg.transform( - lambda node: ( - exp.Identifier(this=node.name, quoted=False) - if isinstance(node, exp.Identifier) - else node - ) - ) - names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) - - return names - - -def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: - return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) - - -# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects -def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: - return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) - - -def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str: - return self.func("MAX", expression.this) - - -def bool_xor_sql(self: Generator, expression: exp.Xor) -> str: - a = self.sql(expression.left) - b = self.sql(expression.right) - return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})" - - -def is_parse_json(expression: exp.Expression) -> bool: - return isinstance(expression, exp.ParseJSON) or ( - isinstance(expression, exp.Cast) and expression.is_type("json") - ) - - -def isnull_to_is_null(args: t.List) -> exp.Expression: - return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null())) - - -def generatedasidentitycolumnconstraint_sql( - self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint -) -> str: - start = self.sql(expression, "start") or "1" - increment = self.sql(expression, "increment") or "1" - return f"IDENTITY({start}, {increment})" - - -def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: - @unsupported_args("count") - def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: - return self.func(name, expression.this, expression.expression) - - return _arg_max_or_min_sql - - -def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: - this = expression.this.copy() - - return_type = expression.return_type - if return_type.is_type(exp.DataType.Type.DATE): - # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we - # can truncate timestamp strings, because some dialects can't cast them to DATE - this = exp.cast(this, exp.DataType.Type.TIMESTAMP) - - expression.this.replace(exp.cast(this, return_type)) - return expression - - -def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: - def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: - if cast and isinstance(expression, exp.TsOrDsAdd): - expression = ts_or_ds_add_cast(expression) - - return self.func( - name, - unit_to_var(expression), - expression.expression, - expression.this, - ) - - return _delta_sql - - -def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: - unit = expression.args.get("unit") - - if isinstance(unit, exp.Placeholder): - return unit - if unit: - return exp.Literal.string(unit.name) - return exp.Literal.string(default) if default else None - - -def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: - unit = expression.args.get("unit") - - if isinstance(unit, (exp.Var, exp.Placeholder)): - return unit - return exp.Var(this=default) if default else None - - -@t.overload -def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var: - pass - - -@t.overload -def map_date_part( - part: t.Optional[exp.Expression], dialect: DialectType = Dialect -) -> 
t.Optional[exp.Expression]: - pass - - -def map_date_part(part, dialect: DialectType = Dialect): - mapped = ( - Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None - ) - return exp.var(mapped) if mapped else part - - -def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: - trunc_curr_date = exp.func("date_trunc", "month", expression.this) - plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") - minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") - - return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE)) - - -def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: - """Remove table refs from columns in when statements.""" - alias = expression.this.args.get("alias") - - def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: - return self.dialect.normalize_identifier(identifier).name if identifier else None - - targets = {normalize(expression.this.this)} - - if alias: - targets.add(normalize(alias.this)) - - for when in expression.args["whens"].expressions: - # only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED - # they are still valid in the , the right hand side of each UPDATE and the VALUES part - # (not the column list) of the INSERT - then: exp.Insert | exp.Update | None = when.args.get("then") - if then: - if isinstance(then, exp.Update): - for equals in then.find_all(exp.EQ): - equal_lhs = equals.this - if ( - isinstance(equal_lhs, exp.Column) - and normalize(equal_lhs.args.get("table")) in targets - ): - equal_lhs.replace(exp.column(equal_lhs.this)) - if isinstance(then, exp.Insert): - column_list = then.this - if isinstance(column_list, exp.Tuple): - for column in column_list.expressions: - if normalize(column.args.get("table")) in targets: - column.replace(exp.column(column.this)) - - return self.merge_sql(expression) - - -def build_json_extract_path( - expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False -) -> t.Callable[[t.List], F]: - def _builder(args: t.List) -> F: - segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] - for arg in args[1:]: - if not isinstance(arg, exp.Literal): - # We use the fallback parser because we can't really transpile non-literals safely - return expr_type.from_arg_list(args) - - text = arg.name - if is_int(text) and (not arrow_req_json_type or not arg.is_string): - index = int(text) - segments.append( - exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) - ) - else: - segments.append(exp.JSONPathKey(this=text)) - - # This is done to avoid failing in the expression validator due to the arg count - del args[2:] - return expr_type( - this=seq_get(args, 0), - expression=exp.JSONPath(expressions=segments), - only_json_types=arrow_req_json_type, - ) - - return _builder - - -def json_extract_segments( - name: str, quoted_index: bool = True, op: t.Optional[str] = None -) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: - def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: - path = expression.expression - if not isinstance(path, exp.JSONPath): - return rename_func(name)(self, expression) - - escape = path.args.get("escape") - - segments = [] - for segment in path.expressions: - path = self.sql(segment) - if path: - if isinstance(segment, exp.JSONPathPart) and ( - quoted_index or not isinstance(segment, exp.JSONPathSubscript) - ): - if escape: - path = self.escape_str(path) - - path = 
f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" - - segments.append(path) - - if op: - return f" {op} ".join([self.sql(expression.this), *segments]) - return self.func(name, expression.this, *segments) - - return _json_extract_segments - - -def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str: - if isinstance(expression.this, exp.JSONPathWildcard): - self.unsupported("Unsupported wildcard in JSONPathKey expression") - - return expression.name - - -def filter_array_using_unnest( - self: Generator, expression: exp.ArrayFilter | exp.ArrayRemove -) -> str: - cond = expression.expression - if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: - alias = cond.expressions[0] - cond = cond.this - elif isinstance(cond, exp.Predicate): - alias = "_u" - elif isinstance(expression, exp.ArrayRemove): - alias = "_u" - cond = exp.NEQ(this=alias, expression=expression.expression) - else: - self.unsupported("Unsupported filter condition") - return "" - - unnest = exp.Unnest(expressions=[expression.this]) - filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) - return self.sql(exp.Array(expressions=[filtered])) - - -def remove_from_array_using_filter(self: Generator, expression: exp.ArrayRemove) -> str: - lambda_id = exp.to_identifier("_u") - cond = exp.NEQ(this=lambda_id, expression=expression.expression) - return self.sql( - exp.ArrayFilter( - this=expression.this, expression=exp.Lambda(this=cond, expressions=[lambda_id]) - ) - ) - - -def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str: - return self.func( - "TO_NUMBER", - expression.this, - expression.args.get("format"), - expression.args.get("nlsparam"), - ) - - -def build_default_decimal_type( - precision: t.Optional[int] = None, scale: t.Optional[int] = None -) -> t.Callable[[exp.DataType], exp.DataType]: - def _builder(dtype: exp.DataType) -> exp.DataType: - if dtype.expressions or precision is None: - return dtype - - params = f"{precision}{f', {scale}' if scale is not None else ''}" - return exp.DataType.build(f"DECIMAL({params})") - - return _builder - - -def build_timestamp_from_parts(args: t.List) -> exp.Func: - if len(args) == 2: - # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, - # so we parse this into Anonymous for now instead of introducing complexity - return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) - - return exp.TimestampFromParts.from_arg_list(args) - - -def sha256_sql(self: Generator, expression: exp.SHA2) -> str: - return self.func(f"SHA{expression.text('length') or '256'}", expression.this) - - -def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: - start = expression.args.get("start") - end = expression.args.get("end") - step = expression.args.get("step") - - if isinstance(start, exp.Cast): - target_type = start.to - elif isinstance(end, exp.Cast): - target_type = end.to - else: - target_type = None - - if start and end and target_type and target_type.is_type("date", "timestamp"): - if isinstance(start, exp.Cast) and target_type is start.to: - end = exp.cast(end, target_type) - else: - start = exp.cast(start, target_type) - - return self.func("SEQUENCE", start, end, step) - - -def build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: - def _builder(args: t.List, dialect: Dialect) -> E: - return expr_type( - this=seq_get(args, 0), - expression=seq_get(args, 1), - group=seq_get(args, 2) or 
exp.Literal.number(dialect.REGEXP_EXTRACT_DEFAULT_GROUP), - parameters=seq_get(args, 3), - ) - - return _builder - - -def explode_to_unnest_sql(self: Generator, expression: exp.Lateral) -> str: - if isinstance(expression.this, exp.Explode): - return self.sql( - exp.Join( - this=exp.Unnest( - expressions=[expression.this.this], - alias=expression.args.get("alias"), - offset=isinstance(expression.this, exp.Posexplode), - ), - kind="cross", - ) - ) - return self.lateral_sql(expression) - - -def timestampdiff_sql(self: Generator, expression: exp.DatetimeDiff | exp.TimestampDiff) -> str: - return self.func("TIMESTAMPDIFF", expression.unit, expression.expression, expression.this) - - -def no_make_interval_sql(self: Generator, expression: exp.MakeInterval, sep: str = ", ") -> str: - args = [] - for unit, value in expression.args.items(): - if isinstance(value, exp.Kwarg): - value = value.expression - - args.append(f"{value} {unit}") - - return f"INTERVAL '{self.format_args(*args, sep=sep)}'" - - -def length_or_char_length_sql(self: Generator, expression: exp.Length) -> str: - length_func = "LENGTH" if expression.args.get("binary") else "CHAR_LENGTH" - return self.func(length_func, expression.this) - - -def groupconcat_sql( - self: Generator, - expression: exp.GroupConcat, - func_name="LISTAGG", - sep: str = ",", - within_group: bool = True, - on_overflow: bool = False, -) -> str: - this = expression.this - separator = self.sql(expression.args.get("separator") or exp.Literal.string(sep)) - - on_overflow_sql = self.sql(expression, "on_overflow") - on_overflow_sql = f" ON OVERFLOW {on_overflow_sql}" if (on_overflow and on_overflow_sql) else "" - - order = this.find(exp.Order) - - if order and order.this: - this = order.this.pop() - - args = self.format_args(this, f"{separator}{on_overflow_sql}") - listagg: exp.Expression = exp.Anonymous(this=func_name, expressions=[args]) - - if order: - if within_group: - listagg = exp.WithinGroup(this=listagg, expression=order) - else: - listagg.set("expressions", [f"{args}{self.sql(expression=expression.this)}"]) - - return self.sql(listagg) - - -def build_timetostr_or_tochar(args: t.List, dialect: Dialect) -> exp.TimeToStr | exp.ToChar: - this = seq_get(args, 0) - - if this and not this.type: - from sqlglot.optimizer.annotate_types import annotate_types - - annotate_types(this, dialect=dialect) - if this.is_type(*exp.DataType.TEMPORAL_TYPES): - dialect_name = dialect.__class__.__name__.lower() - return build_formatted_time(exp.TimeToStr, dialect_name, default=True)(args) - - return exp.ToChar.from_arg_list(args) diff --git a/altimate_packages/sqlglot/dialects/doris.py b/altimate_packages/sqlglot/dialects/doris.py deleted file mode 100644 index fabd85040..000000000 --- a/altimate_packages/sqlglot/dialects/doris.py +++ /dev/null @@ -1,561 +0,0 @@ -from __future__ import annotations - -from sqlglot import exp -from sqlglot.dialects.dialect import ( - approx_count_distinct_sql, - build_timestamp_trunc, - rename_func, - time_format, - unit_to_str, -) -from sqlglot.dialects.mysql import MySQL - - -def _lag_lead_sql(self, expression: exp.Lag | exp.Lead) -> str: - return self.func( - "LAG" if isinstance(expression, exp.Lag) else "LEAD", - expression.this, - expression.args.get("offset") or exp.Literal.number(1), - expression.args.get("default") or exp.null(), - ) - - -class Doris(MySQL): - DATE_FORMAT = "'yyyy-MM-dd'" - DATEINT_FORMAT = "'yyyyMMdd'" - TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'" - - class Parser(MySQL.Parser): - FUNCTIONS = { - **MySQL.Parser.FUNCTIONS, - 
"COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, - "DATE_TRUNC": build_timestamp_trunc, - "MONTHS_ADD": exp.AddMonths.from_arg_list, - "REGEXP": exp.RegexpLike.from_arg_list, - "TO_DATE": exp.TsOrDsToDate.from_arg_list, - } - - FUNCTION_PARSERS = MySQL.Parser.FUNCTION_PARSERS.copy() - FUNCTION_PARSERS.pop("GROUP_CONCAT") - - class Generator(MySQL.Generator): - LAST_DAY_SUPPORTS_DATE_PART = False - VARCHAR_REQUIRES_SIZE = False - - TYPE_MAPPING = { - **MySQL.Generator.TYPE_MAPPING, - exp.DataType.Type.TEXT: "STRING", - exp.DataType.Type.TIMESTAMP: "DATETIME", - exp.DataType.Type.TIMESTAMPTZ: "DATETIME", - } - - CAST_MAPPING = {} - TIMESTAMP_FUNC_TYPES = set() - - TRANSFORMS = { - **MySQL.Generator.TRANSFORMS, - exp.AddMonths: rename_func("MONTHS_ADD"), - exp.ApproxDistinct: approx_count_distinct_sql, - exp.ArgMax: rename_func("MAX_BY"), - exp.ArgMin: rename_func("MIN_BY"), - exp.ArrayAgg: rename_func("COLLECT_LIST"), - exp.ArrayToString: rename_func("ARRAY_JOIN"), - exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), - exp.CurrentTimestamp: lambda self, _: self.func("NOW"), - exp.DateTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)), - exp.GroupConcat: lambda self, e: self.func( - "GROUP_CONCAT", e.this, e.args.get("separator") or exp.Literal.string(",") - ), - exp.JSONExtractScalar: lambda self, e: self.func("JSON_EXTRACT", e.this, e.expression), - exp.Lag: _lag_lead_sql, - exp.Lead: _lag_lead_sql, - exp.Map: rename_func("ARRAY_MAP"), - exp.RegexpLike: rename_func("REGEXP"), - exp.RegexpSplit: rename_func("SPLIT_BY_STRING"), - exp.Split: rename_func("SPLIT_BY_STRING"), - exp.StringToArray: rename_func("SPLIT_BY_STRING"), - exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)), - exp.TimeStrToDate: rename_func("TO_DATE"), - exp.TsOrDsAdd: lambda self, e: self.func("DATE_ADD", e.this, e.expression), - exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this), - exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), - exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", e.this, unit_to_str(e)), - exp.UnixToStr: lambda self, e: self.func( - "FROM_UNIXTIME", e.this, time_format("doris")(self, e) - ), - exp.UnixToTime: rename_func("FROM_UNIXTIME"), - } - - # https://github.com/apache/doris/blob/e4f41dbf1ec03f5937fdeba2ee1454a20254015b/fe/fe-core/src/main/antlr4/org/apache/doris/nereids/DorisLexer.g4#L93 - RESERVED_KEYWORDS = { - "account_lock", - "account_unlock", - "add", - "adddate", - "admin", - "after", - "agg_state", - "aggregate", - "alias", - "all", - "alter", - "analyze", - "analyzed", - "and", - "anti", - "append", - "array", - "array_range", - "as", - "asc", - "at", - "authors", - "auto", - "auto_increment", - "backend", - "backends", - "backup", - "begin", - "belong", - "between", - "bigint", - "bin", - "binary", - "binlog", - "bitand", - "bitmap", - "bitmap_union", - "bitor", - "bitxor", - "blob", - "boolean", - "brief", - "broker", - "buckets", - "build", - "builtin", - "bulk", - "by", - "cached", - "call", - "cancel", - "case", - "cast", - "catalog", - "catalogs", - "chain", - "char", - "character", - "charset", - "check", - "clean", - "cluster", - "clusters", - "collate", - "collation", - "collect", - "column", - "columns", - "comment", - "commit", - "committed", - "compact", - "complete", - "config", - "connection", - "connection_id", - "consistent", - "constraint", - "constraints", - "convert", - "copy", - "count", - "create", - "creation", - "cron", - "cross", - "cube", - "current", - "current_catalog", - "current_date", - 
"current_time", - "current_timestamp", - "current_user", - "data", - "database", - "databases", - "date", - "date_add", - "date_ceil", - "date_diff", - "date_floor", - "date_sub", - "dateadd", - "datediff", - "datetime", - "datetimev2", - "datev2", - "datetimev1", - "datev1", - "day", - "days_add", - "days_sub", - "decimal", - "decimalv2", - "decimalv3", - "decommission", - "default", - "deferred", - "delete", - "demand", - "desc", - "describe", - "diagnose", - "disk", - "distinct", - "distinctpc", - "distinctpcsa", - "distributed", - "distribution", - "div", - "do", - "doris_internal_table_id", - "double", - "drop", - "dropp", - "dual", - "duplicate", - "dynamic", - "else", - "enable", - "encryptkey", - "encryptkeys", - "end", - "ends", - "engine", - "engines", - "enter", - "errors", - "events", - "every", - "except", - "exclude", - "execute", - "exists", - "expired", - "explain", - "export", - "extended", - "external", - "extract", - "failed_login_attempts", - "false", - "fast", - "feature", - "fields", - "file", - "filter", - "first", - "float", - "follower", - "following", - "for", - "foreign", - "force", - "format", - "free", - "from", - "frontend", - "frontends", - "full", - "function", - "functions", - "generic", - "global", - "grant", - "grants", - "graph", - "group", - "grouping", - "groups", - "hash", - "having", - "hdfs", - "help", - "histogram", - "hll", - "hll_union", - "hostname", - "hour", - "hub", - "identified", - "if", - "ignore", - "immediate", - "in", - "incremental", - "index", - "indexes", - "infile", - "inner", - "insert", - "install", - "int", - "integer", - "intermediate", - "intersect", - "interval", - "into", - "inverted", - "ipv4", - "ipv6", - "is", - "is_not_null_pred", - "is_null_pred", - "isnull", - "isolation", - "job", - "jobs", - "join", - "json", - "jsonb", - "key", - "keys", - "kill", - "label", - "largeint", - "last", - "lateral", - "ldap", - "ldap_admin_password", - "left", - "less", - "level", - "like", - "limit", - "lines", - "link", - "list", - "load", - "local", - "localtime", - "localtimestamp", - "location", - "lock", - "logical", - "low_priority", - "manual", - "map", - "match", - "match_all", - "match_any", - "match_phrase", - "match_phrase_edge", - "match_phrase_prefix", - "match_regexp", - "materialized", - "max", - "maxvalue", - "memo", - "merge", - "migrate", - "migrations", - "min", - "minus", - "minute", - "modify", - "month", - "mtmv", - "name", - "names", - "natural", - "negative", - "never", - "next", - "ngram_bf", - "no", - "non_nullable", - "not", - "null", - "nulls", - "observer", - "of", - "offset", - "on", - "only", - "open", - "optimized", - "or", - "order", - "outer", - "outfile", - "over", - "overwrite", - "parameter", - "parsed", - "partition", - "partitions", - "password", - "password_expire", - "password_history", - "password_lock_time", - "password_reuse", - "path", - "pause", - "percent", - "period", - "permissive", - "physical", - "plan", - "process", - "plugin", - "plugins", - "policy", - "preceding", - "prepare", - "primary", - "proc", - "procedure", - "processlist", - "profile", - "properties", - "property", - "quantile_state", - "quantile_union", - "query", - "quota", - "random", - "range", - "read", - "real", - "rebalance", - "recover", - "recycle", - "refresh", - "references", - "regexp", - "release", - "rename", - "repair", - "repeatable", - "replace", - "replace_if_not_null", - "replica", - "repositories", - "repository", - "resource", - "resources", - "restore", - "restrictive", - "resume", - "returns", - 
"revoke", - "rewritten", - "right", - "rlike", - "role", - "roles", - "rollback", - "rollup", - "routine", - "row", - "rows", - "s3", - "sample", - "schedule", - "scheduler", - "schema", - "schemas", - "second", - "select", - "semi", - "sequence", - "serializable", - "session", - "set", - "sets", - "shape", - "show", - "signed", - "skew", - "smallint", - "snapshot", - "soname", - "split", - "sql_block_rule", - "start", - "starts", - "stats", - "status", - "stop", - "storage", - "stream", - "streaming", - "string", - "struct", - "subdate", - "sum", - "superuser", - "switch", - "sync", - "system", - "table", - "tables", - "tablesample", - "tablet", - "tablets", - "task", - "tasks", - "temporary", - "terminated", - "text", - "than", - "then", - "time", - "timestamp", - "timestampadd", - "timestampdiff", - "tinyint", - "to", - "transaction", - "trash", - "tree", - "triggers", - "trim", - "true", - "truncate", - "type", - "type_cast", - "types", - "unbounded", - "uncommitted", - "uninstall", - "union", - "unique", - "unlock", - "unsigned", - "update", - "use", - "user", - "using", - "value", - "values", - "varchar", - "variables", - "variant", - "vault", - "verbose", - "version", - "view", - "warnings", - "week", - "when", - "where", - "whitelist", - "with", - "work", - "workload", - "write", - "xor", - "year", - } diff --git a/altimate_packages/sqlglot/dialects/drill.py b/altimate_packages/sqlglot/dialects/drill.py deleted file mode 100644 index 8a80c2f5f..000000000 --- a/altimate_packages/sqlglot/dialects/drill.py +++ /dev/null @@ -1,157 +0,0 @@ -from __future__ import annotations - - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - datestrtodate_sql, - build_formatted_time, - no_trycast_sql, - rename_func, - strposition_sql, - timestrtotime_sql, -) -from sqlglot.dialects.mysql import date_add_sql -from sqlglot.transforms import preprocess, move_schema_columns_to_partitioned_by -from sqlglot.generator import unsupported_args - - -def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str: - this = self.sql(expression, "this") - time_format = self.format_time(expression) - if time_format == Drill.DATE_FORMAT: - return self.sql(exp.cast(this, exp.DataType.Type.DATE)) - return self.func("TO_DATE", this, time_format) - - -class Drill(Dialect): - NORMALIZE_FUNCTIONS: bool | str = False - PRESERVE_ORIGINAL_NAMES = True - NULL_ORDERING = "nulls_are_last" - DATE_FORMAT = "'yyyy-MM-dd'" - DATEINT_FORMAT = "'yyyyMMdd'" - TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'" - SUPPORTS_USER_DEFINED_TYPES = False - SUPPORTS_SEMI_ANTI_JOIN = False - TYPED_DIVISION = True - CONCAT_COALESCE = True - - TIME_MAPPING = { - "y": "%Y", - "Y": "%Y", - "YYYY": "%Y", - "yyyy": "%Y", - "YY": "%y", - "yy": "%y", - "MMMM": "%B", - "MMM": "%b", - "MM": "%m", - "M": "%-m", - "dd": "%d", - "d": "%-d", - "HH": "%H", - "H": "%-H", - "hh": "%I", - "h": "%-I", - "mm": "%M", - "m": "%-M", - "ss": "%S", - "s": "%-S", - "SSSSSS": "%f", - "a": "%p", - "DD": "%j", - "D": "%-j", - "E": "%a", - "EE": "%a", - "EEE": "%a", - "EEEE": "%A", - "''T''": "T", - } - - class Tokenizer(tokens.Tokenizer): - IDENTIFIERS = ["`"] - STRING_ESCAPES = ["\\"] - - KEYWORDS = tokens.Tokenizer.KEYWORDS.copy() - KEYWORDS.pop("/*+") - - class Parser(parser.Parser): - STRICT_CAST = False - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "REPEATED_COUNT": exp.ArraySize.from_arg_list, - "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, - "TO_CHAR": build_formatted_time(exp.TimeToStr, "drill"), 
- "LEVENSHTEIN_DISTANCE": exp.Levenshtein.from_arg_list, - } - - LOG_DEFAULTS_TO_LN = True - - class Generator(generator.Generator): - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - NVL2_SUPPORTED = False - LAST_DAY_SUPPORTS_DATE_PART = False - SUPPORTS_CREATE_TABLE_LIKE = False - ARRAY_SIZE_NAME = "REPEATED_COUNT" - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.INT: "INTEGER", - exp.DataType.Type.SMALLINT: "INTEGER", - exp.DataType.Type.TINYINT: "INTEGER", - exp.DataType.Type.BINARY: "VARBINARY", - exp.DataType.Type.TEXT: "VARCHAR", - exp.DataType.Type.NCHAR: "VARCHAR", - exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", - exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.DATETIME: "TIMESTAMP", - } - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.ArrayContains: rename_func("REPEATED_CONTAINS"), - exp.Create: preprocess([move_schema_columns_to_partitioned_by]), - exp.DateAdd: date_add_sql("ADD"), - exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: date_add_sql("SUB"), - exp.DateToDi: lambda self, - e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.DATEINT_FORMAT}) AS INT)", - exp.DiToDate: lambda self, - e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.DATEINT_FORMAT})", - exp.If: lambda self, - e: f"`IF`({self.format_args(e.this, e.args.get('true'), e.args.get('false'))})", - exp.ILike: lambda self, e: self.binary(e, "`ILIKE`"), - exp.Levenshtein: unsupported_args("ins_cost", "del_cost", "sub_cost", "max_dist")( - rename_func("LEVENSHTEIN_DISTANCE") - ), - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", - exp.RegexpLike: rename_func("REGEXP_MATCHES"), - exp.StrToDate: _str_to_date, - exp.Pow: rename_func("POW"), - exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] - ), - exp.StrPosition: strposition_sql, - exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), - exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)), - exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), - exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)), - exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), - exp.ToChar: lambda self, e: self.function_fallback_sql(e), - exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: lambda self, - e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})", - exp.TsOrDiToDi: lambda self, - e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", - } diff --git a/altimate_packages/sqlglot/dialects/druid.py b/altimate_packages/sqlglot/dialects/druid.py deleted file mode 100644 index 8b95abb57..000000000 --- a/altimate_packages/sqlglot/dialects/druid.py +++ /dev/null @@ -1,20 +0,0 @@ -from sqlglot import exp, generator -from sqlglot.dialects.dialect import rename_func, Dialect - - -class Druid(Dialect): - class Generator(generator.Generator): - # https://druid.apache.org/docs/latest/querying/sql-data-types/ - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.NCHAR: "STRING", - exp.DataType.Type.NVARCHAR: "STRING", - 
exp.DataType.Type.TEXT: "STRING", - exp.DataType.Type.UUID: "STRING", - } - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.Mod: rename_func("MOD"), - } diff --git a/altimate_packages/sqlglot/dialects/duckdb.py b/altimate_packages/sqlglot/dialects/duckdb.py deleted file mode 100644 index a87ddc8ce..000000000 --- a/altimate_packages/sqlglot/dialects/duckdb.py +++ /dev/null @@ -1,1159 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.expressions import DATA_TYPE -from sqlglot.dialects.dialect import ( - Dialect, - JSON_EXTRACT_TYPE, - NormalizationStrategy, - Version, - approx_count_distinct_sql, - arrow_json_extract_sql, - binary_from_function, - bool_xor_sql, - build_default_decimal_type, - count_if_to_sum, - date_trunc_to_time, - datestrtodate_sql, - no_datetime_sql, - encode_decode_sql, - build_formatted_time, - inline_array_unless_query, - no_comment_column_constraint_sql, - no_time_sql, - no_timestamp_sql, - pivot_column_names, - rename_func, - remove_from_array_using_filter, - strposition_sql, - str_to_time_sql, - timestamptrunc_sql, - timestrtotime_sql, - unit_to_var, - unit_to_str, - sha256_sql, - build_regexp_extract, - explode_to_unnest_sql, - no_make_interval_sql, - groupconcat_sql, -) -from sqlglot.generator import unsupported_args -from sqlglot.helper import seq_get -from sqlglot.tokens import TokenType -from sqlglot.parser import binary_range_parser - -DATETIME_DELTA = t.Union[ - exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd, exp.DateSub, exp.DatetimeSub -] - - -def _date_delta_sql(self: DuckDB.Generator, expression: DATETIME_DELTA) -> str: - this = expression.this - unit = unit_to_var(expression) - op = ( - "+" - if isinstance(expression, (exp.DateAdd, exp.TimeAdd, exp.DatetimeAdd, exp.TsOrDsAdd)) - else "-" - ) - - to_type: t.Optional[DATA_TYPE] = None - if isinstance(expression, exp.TsOrDsAdd): - to_type = expression.return_type - elif this.is_string: - # Cast string literals (i.e function parameters) to the appropriate type for +/- interval to work - to_type = ( - exp.DataType.Type.DATETIME - if isinstance(expression, (exp.DatetimeAdd, exp.DatetimeSub)) - else exp.DataType.Type.DATE - ) - - this = exp.cast(this, to_type) if to_type else this - - expr = expression.expression - interval = expr if isinstance(expr, exp.Interval) else exp.Interval(this=expr, unit=unit) - - return f"{self.sql(this)} {op} {self.sql(interval)}" - - -# BigQuery -> DuckDB conversion for the DATE function -def _date_sql(self: DuckDB.Generator, expression: exp.Date) -> str: - result = f"CAST({self.sql(expression, 'this')} AS DATE)" - zone = self.sql(expression, "zone") - - if zone: - date_str = self.func("STRFTIME", result, "'%d/%m/%Y'") - date_str = f"{date_str} || ' ' || {zone}" - - # This will create a TIMESTAMP with time zone information - result = self.func("STRPTIME", date_str, "'%d/%m/%Y %Z'") - - return result - - -# BigQuery -> DuckDB conversion for the TIME_DIFF function -def _timediff_sql(self: DuckDB.Generator, expression: exp.TimeDiff) -> str: - this = exp.cast(expression.this, exp.DataType.Type.TIME) - expr = exp.cast(expression.expression, exp.DataType.Type.TIME) - - # Although the 2 dialects share similar signatures, BQ seems to inverse - # the sign of the result so the start/end time operands are flipped - return self.func("DATE_DIFF", unit_to_str(expression), expr, this) - - -@unsupported_args(("expression", 
"DuckDB's ARRAY_SORT does not support a comparator.")) -def _array_sort_sql(self: DuckDB.Generator, expression: exp.ArraySort) -> str: - return self.func("ARRAY_SORT", expression.this) - - -def _sort_array_sql(self: DuckDB.Generator, expression: exp.SortArray) -> str: - name = "ARRAY_REVERSE_SORT" if expression.args.get("asc") == exp.false() else "ARRAY_SORT" - return self.func(name, expression.this) - - -def _build_sort_array_desc(args: t.List) -> exp.Expression: - return exp.SortArray(this=seq_get(args, 0), asc=exp.false()) - - -def _build_date_diff(args: t.List) -> exp.Expression: - return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0)) - - -def _build_generate_series(end_exclusive: bool = False) -> t.Callable[[t.List], exp.GenerateSeries]: - def _builder(args: t.List) -> exp.GenerateSeries: - # Check https://duckdb.org/docs/sql/functions/nested.html#range-functions - if len(args) == 1: - # DuckDB uses 0 as a default for the series' start when it's omitted - args.insert(0, exp.Literal.number("0")) - - gen_series = exp.GenerateSeries.from_arg_list(args) - gen_series.set("is_end_exclusive", end_exclusive) - - return gen_series - - return _builder - - -def _build_make_timestamp(args: t.List) -> exp.Expression: - if len(args) == 1: - return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS) - - return exp.TimestampFromParts( - year=seq_get(args, 0), - month=seq_get(args, 1), - day=seq_get(args, 2), - hour=seq_get(args, 3), - min=seq_get(args, 4), - sec=seq_get(args, 5), - ) - - -def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[DuckDB.Parser], exp.Show]: - def _parse(self: DuckDB.Parser) -> exp.Show: - return self._parse_show_duckdb(*args, **kwargs) - - return _parse - - -def _struct_sql(self: DuckDB.Generator, expression: exp.Struct) -> str: - args: t.List[str] = [] - - # BigQuery allows inline construction such as "STRUCT('str', 1)" which is - # canonicalized to "ROW('str', 1) AS STRUCT(a TEXT, b INT)" in DuckDB - # The transformation to ROW will take place if: - # 1. The STRUCT itself does not have proper fields (key := value) as a "proper" STRUCT would - # 2. 
A cast to STRUCT / ARRAY of STRUCTs is found - ancestor_cast = expression.find_ancestor(exp.Cast) - is_bq_inline_struct = ( - (expression.find(exp.PropertyEQ) is None) - and ancestor_cast - and any( - casted_type.is_type(exp.DataType.Type.STRUCT) - for casted_type in ancestor_cast.find_all(exp.DataType) - ) - ) - - for i, expr in enumerate(expression.expressions): - is_property_eq = isinstance(expr, exp.PropertyEQ) - value = expr.expression if is_property_eq else expr - - if is_bq_inline_struct: - args.append(self.sql(value)) - else: - key = expr.name if is_property_eq else f"_{i}" - args.append(f"{self.sql(exp.Literal.string(key))}: {self.sql(value)}") - - csv_args = ", ".join(args) - - return f"ROW({csv_args})" if is_bq_inline_struct else f"{{{csv_args}}}" - - -def _datatype_sql(self: DuckDB.Generator, expression: exp.DataType) -> str: - if expression.is_type("array"): - return f"{self.expressions(expression, flat=True)}[{self.expressions(expression, key='values', flat=True)}]" - - # Modifiers are not supported for TIME, [TIME | TIMESTAMP] WITH TIME ZONE - if expression.is_type( - exp.DataType.Type.TIME, exp.DataType.Type.TIMETZ, exp.DataType.Type.TIMESTAMPTZ - ): - return expression.this.value - - return self.datatype_sql(expression) - - -def _json_format_sql(self: DuckDB.Generator, expression: exp.JSONFormat) -> str: - sql = self.func("TO_JSON", expression.this, expression.args.get("options")) - return f"CAST({sql} AS TEXT)" - - -def _unix_to_time_sql(self: DuckDB.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = expression.this - - if scale in (None, exp.UnixToTime.SECONDS): - return self.func("TO_TIMESTAMP", timestamp) - if scale == exp.UnixToTime.MILLIS: - return self.func("EPOCH_MS", timestamp) - if scale == exp.UnixToTime.MICROS: - return self.func("MAKE_TIMESTAMP", timestamp) - - return self.func("TO_TIMESTAMP", exp.Div(this=timestamp, expression=exp.func("POW", 10, scale))) - - -WRAPPED_JSON_EXTRACT_EXPRESSIONS = (exp.Binary, exp.Bracket, exp.In) - - -def _arrow_json_extract_sql(self: DuckDB.Generator, expression: JSON_EXTRACT_TYPE) -> str: - arrow_sql = arrow_json_extract_sql(self, expression) - if not expression.same_parent and isinstance( - expression.parent, WRAPPED_JSON_EXTRACT_EXPRESSIONS - ): - arrow_sql = self.wrap(arrow_sql) - return arrow_sql - - -def _implicit_datetime_cast( - arg: t.Optional[exp.Expression], type: exp.DataType.Type = exp.DataType.Type.DATE -) -> t.Optional[exp.Expression]: - return exp.cast(arg, type) if isinstance(arg, exp.Literal) else arg - - -def _date_diff_sql(self: DuckDB.Generator, expression: exp.DateDiff) -> str: - this = _implicit_datetime_cast(expression.this) - expr = _implicit_datetime_cast(expression.expression) - - return self.func("DATE_DIFF", unit_to_str(expression), expr, this) - - -def _generate_datetime_array_sql( - self: DuckDB.Generator, expression: t.Union[exp.GenerateDateArray, exp.GenerateTimestampArray] -) -> str: - is_generate_date_array = isinstance(expression, exp.GenerateDateArray) - - type = exp.DataType.Type.DATE if is_generate_date_array else exp.DataType.Type.TIMESTAMP - start = _implicit_datetime_cast(expression.args.get("start"), type=type) - end = _implicit_datetime_cast(expression.args.get("end"), type=type) - - # BQ's GENERATE_DATE_ARRAY & GENERATE_TIMESTAMP_ARRAY are transformed to DuckDB'S GENERATE_SERIES - gen_series: t.Union[exp.GenerateSeries, exp.Cast] = exp.GenerateSeries( - start=start, end=end, step=expression.args.get("step") - ) - - if 
is_generate_date_array: - # The GENERATE_SERIES result type is TIMESTAMP array, so to match BQ's semantics for - # GENERATE_DATE_ARRAY we must cast it back to DATE array - gen_series = exp.cast(gen_series, exp.DataType.build("ARRAY<DATE>")) - - return self.sql(gen_series) - - -def _json_extract_value_array_sql( - self: DuckDB.Generator, expression: exp.JSONValueArray | exp.JSONExtractArray -) -> str: - json_extract = exp.JSONExtract(this=expression.this, expression=expression.expression) - data_type = "ARRAY<STRING>" if isinstance(expression, exp.JSONValueArray) else "ARRAY<JSON>" - return self.sql(exp.cast(json_extract, to=exp.DataType.build(data_type))) - - -class DuckDB(Dialect): - NULL_ORDERING = "nulls_are_last" - SUPPORTS_USER_DEFINED_TYPES = True - SAFE_DIVISION = True - INDEX_OFFSET = 1 - CONCAT_COALESCE = True - SUPPORTS_ORDER_BY_ALL = True - SUPPORTS_FIXED_SIZE_ARRAYS = True - STRICT_JSON_PATH_SYNTAX = False - NUMBERS_CAN_BE_UNDERSCORE_SEPARATED = True - - # https://duckdb.org/docs/sql/introduction.html#creating-a-new-table - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE - - def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if isinstance(path, exp.Literal): - # DuckDB also supports the JSON pointer syntax, where every path starts with a `/`. - # Additionally, it allows accessing the back of lists using the `[#-i]` syntax. - # This check ensures we'll avoid trying to parse these as JSON paths, which can - # either result in a noisy warning or in an invalid representation of the path. - path_text = path.name - if path_text.startswith("/") or "[#" in path_text: - return path - - return super().to_json_path(path) - - class Tokenizer(tokens.Tokenizer): - BYTE_STRINGS = [("e'", "'"), ("E'", "'")] - HEREDOC_STRINGS = ["$"] - - HEREDOC_TAG_IS_IDENTIFIER = True - HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "//": TokenType.DIV, - "**": TokenType.DSTAR, - "^@": TokenType.CARET_AT, - "@>": TokenType.AT_GT, - "<@": TokenType.LT_AT, - "ATTACH": TokenType.ATTACH, - "BINARY": TokenType.VARBINARY, - "BITSTRING": TokenType.BIT, - "BPCHAR": TokenType.TEXT, - "CHAR": TokenType.TEXT, - "DATETIME": TokenType.TIMESTAMPNTZ, - "DETACH": TokenType.DETACH, - "EXCLUDE": TokenType.EXCEPT, - "LOGICAL": TokenType.BOOLEAN, - "ONLY": TokenType.ONLY, - "PIVOT_WIDER": TokenType.PIVOT, - "POSITIONAL": TokenType.POSITIONAL, - "SIGNED": TokenType.INT, - "STRING": TokenType.TEXT, - "SUMMARIZE": TokenType.SUMMARIZE, - "TIMESTAMP": TokenType.TIMESTAMPNTZ, - "TIMESTAMP_S": TokenType.TIMESTAMP_S, - "TIMESTAMP_MS": TokenType.TIMESTAMP_MS, - "TIMESTAMP_NS": TokenType.TIMESTAMP_NS, - "TIMESTAMP_US": TokenType.TIMESTAMP, - "UBIGINT": TokenType.UBIGINT, - "UINTEGER": TokenType.UINT, - "USMALLINT": TokenType.USMALLINT, - "UTINYINT": TokenType.UTINYINT, - "VARCHAR": TokenType.TEXT, - } - KEYWORDS.pop("/*+") - - SINGLE_TOKENS = { - **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, - } - - COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} - - class Parser(parser.Parser): - BITWISE = { - **parser.Parser.BITWISE, - TokenType.TILDA: exp.RegexpLike, - } - BITWISE.pop(TokenType.CARET) - - RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, - TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), - TokenType.CARET_AT: binary_range_parser(exp.StartsWith), - } - - EXPONENT = { - **parser.Parser.EXPONENT, - TokenType.CARET: exp.Pow, - TokenType.DSTAR: exp.Pow, - } - - FUNCTIONS_WITH_ALIASED_ARGS =
{*parser.Parser.FUNCTIONS_WITH_ALIASED_ARGS, "STRUCT_PACK"} - - SHOW_PARSERS = { - "TABLES": _show_parser("TABLES"), - "ALL TABLES": _show_parser("ALL TABLES"), - } - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "ARRAY_REVERSE_SORT": _build_sort_array_desc, - "ARRAY_SORT": exp.SortArray.from_arg_list, - "DATEDIFF": _build_date_diff, - "DATE_DIFF": _build_date_diff, - "DATE_TRUNC": date_trunc_to_time, - "DATETRUNC": date_trunc_to_time, - "DECODE": lambda args: exp.Decode( - this=seq_get(args, 0), charset=exp.Literal.string("utf-8") - ), - "EDITDIST3": exp.Levenshtein.from_arg_list, - "ENCODE": lambda args: exp.Encode( - this=seq_get(args, 0), charset=exp.Literal.string("utf-8") - ), - "EPOCH": exp.TimeToUnix.from_arg_list, - "EPOCH_MS": lambda args: exp.UnixToTime( - this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS - ), - "GENERATE_SERIES": _build_generate_series(), - "JSON": exp.ParseJSON.from_arg_list, - "JSON_EXTRACT_PATH": parser.build_extract_json_with_path(exp.JSONExtract), - "JSON_EXTRACT_STRING": parser.build_extract_json_with_path(exp.JSONExtractScalar), - "LIST_HAS": exp.ArrayContains.from_arg_list, - "LIST_REVERSE_SORT": _build_sort_array_desc, - "LIST_SORT": exp.SortArray.from_arg_list, - "LIST_VALUE": lambda args: exp.Array(expressions=args), - "MAKE_TIME": exp.TimeFromParts.from_arg_list, - "MAKE_TIMESTAMP": _build_make_timestamp, - "QUANTILE_CONT": exp.PercentileCont.from_arg_list, - "QUANTILE_DISC": exp.PercentileDisc.from_arg_list, - "RANGE": _build_generate_series(end_exclusive=True), - "REGEXP_EXTRACT": build_regexp_extract(exp.RegexpExtract), - "REGEXP_EXTRACT_ALL": build_regexp_extract(exp.RegexpExtractAll), - "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, - "REGEXP_REPLACE": lambda args: exp.RegexpReplace( - this=seq_get(args, 0), - expression=seq_get(args, 1), - replacement=seq_get(args, 2), - modifiers=seq_get(args, 3), - ), - "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), - "STRFTIME": build_formatted_time(exp.TimeToStr, "duckdb"), - "STRING_SPLIT": exp.Split.from_arg_list, - "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, - "STRING_TO_ARRAY": exp.Split.from_arg_list, - "STRPTIME": build_formatted_time(exp.StrToTime, "duckdb"), - "STRUCT_PACK": exp.Struct.from_arg_list, - "STR_SPLIT": exp.Split.from_arg_list, - "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, - "TIME_BUCKET": exp.DateBin.from_arg_list, - "TO_TIMESTAMP": exp.UnixToTime.from_arg_list, - "UNNEST": exp.Explode.from_arg_list, - "XOR": binary_from_function(exp.BitwiseXor), - } - - FUNCTIONS.pop("DATE_SUB") - FUNCTIONS.pop("GLOB") - - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, - **dict.fromkeys( - ("GROUP_CONCAT", "LISTAGG", "STRINGAGG"), lambda self: self._parse_string_agg() - ), - } - FUNCTION_PARSERS.pop("DECODE") - - NO_PAREN_FUNCTION_PARSERS = { - **parser.Parser.NO_PAREN_FUNCTION_PARSERS, - "MAP": lambda self: self._parse_map(), - "@": lambda self: exp.Abs(this=self._parse_bitwise()), - } - - TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { - TokenType.SEMI, - TokenType.ANTI, - } - - PLACEHOLDER_PARSERS = { - **parser.Parser.PLACEHOLDER_PARSERS, - TokenType.PARAMETER: lambda self: ( - self.expression(exp.Placeholder, this=self._prev.text) - if self._match(TokenType.NUMBER) or self._match_set(self.ID_VAR_TOKENS) - else None - ), - } - - TYPE_CONVERTERS = { - # https://duckdb.org/docs/sql/data_types/numeric - exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=18, scale=3), - # 
https://duckdb.org/docs/sql/data_types/text - exp.DataType.Type.TEXT: lambda dtype: exp.DataType.build("TEXT"), - } - - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, - TokenType.ATTACH: lambda self: self._parse_attach_detach(), - TokenType.DETACH: lambda self: self._parse_attach_detach(is_attach=False), - TokenType.SHOW: lambda self: self._parse_show(), - } - - def _parse_expression(self) -> t.Optional[exp.Expression]: - # DuckDB supports prefix aliases, e.g. foo: 1 - if self._next and self._next.token_type == TokenType.COLON: - alias = self._parse_id_var(tokens=self.ALIAS_TOKENS) - self._match(TokenType.COLON) - comments = self._prev_comments or [] - - this = self._parse_assignment() - if isinstance(this, exp.Expression): - # Moves the comment next to the alias in `alias: expr /* comment */` - comments += this.pop_comments() or [] - - return self.expression(exp.Alias, comments=comments, this=this, alias=alias) - - return super()._parse_expression() - - def _parse_table( - self, - schema: bool = False, - joins: bool = False, - alias_tokens: t.Optional[t.Collection[TokenType]] = None, - parse_bracket: bool = False, - is_db_reference: bool = False, - parse_partition: bool = False, - ) -> t.Optional[exp.Expression]: - # DuckDB supports prefix aliases, e.g. FROM foo: bar - if self._next and self._next.token_type == TokenType.COLON: - alias = self._parse_table_alias( - alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS - ) - self._match(TokenType.COLON) - comments = self._prev_comments or [] - else: - alias = None - comments = [] - - table = super()._parse_table( - schema=schema, - joins=joins, - alias_tokens=alias_tokens, - parse_bracket=parse_bracket, - is_db_reference=is_db_reference, - parse_partition=parse_partition, - ) - if isinstance(table, exp.Expression) and isinstance(alias, exp.TableAlias): - # Moves the comment next to the alias in `alias: table /* comment */` - comments += table.pop_comments() or [] - alias.comments = alias.pop_comments() + comments - table.set("alias", alias) - - return table - - def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]: - # https://duckdb.org/docs/sql/samples.html - sample = super()._parse_table_sample(as_modifier=as_modifier) - if sample and not sample.args.get("method"): - if sample.args.get("size"): - sample.set("method", exp.var("RESERVOIR")) - else: - sample.set("method", exp.var("SYSTEM")) - - return sample - - def _parse_bracket( - self, this: t.Optional[exp.Expression] = None - ) -> t.Optional[exp.Expression]: - bracket = super()._parse_bracket(this) - - if self.dialect.version < Version("1.2.0") and isinstance(bracket, exp.Bracket): - # https://duckdb.org/2025/02/05/announcing-duckdb-120.html#breaking-changes - bracket.set("returns_list_for_maps", True) - - return bracket - - def _parse_map(self) -> exp.ToMap | exp.Map: - if self._match(TokenType.L_BRACE, advance=False): - return self.expression(exp.ToMap, this=self._parse_bracket()) - - args = self._parse_wrapped_csv(self._parse_assignment) - return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1)) - - def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]: - return self._parse_field_def() - - def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: - if len(aggregations) == 1: - return super()._pivot_column_names(aggregations) - return pivot_column_names(aggregations, dialect="duckdb") - - def _parse_attach_detach(self, is_attach=True) -> exp.Attach | 
exp.Detach: - def _parse_attach_option() -> exp.AttachOption: - return self.expression( - exp.AttachOption, - this=self._parse_var(any_token=True), - expression=self._parse_field(any_token=True), - ) - - self._match(TokenType.DATABASE) - exists = self._parse_exists(not_=is_attach) - this = self._parse_alias(self._parse_primary_or_var(), explicit=True) - - if self._match(TokenType.L_PAREN, advance=False): - expressions = self._parse_wrapped_csv(_parse_attach_option) - else: - expressions = None - - return ( - self.expression(exp.Attach, this=this, exists=exists, expressions=expressions) - if is_attach - else self.expression(exp.Detach, this=this, exists=exists) - ) - - def _parse_show_duckdb(self, this: str) -> exp.Show: - return self.expression(exp.Show, this=this) - - class Generator(generator.Generator): - PARAMETER_TOKEN = "$" - NAMED_PLACEHOLDER_TOKEN = "$" - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - LIMIT_FETCH = "LIMIT" - STRUCT_DELIMITER = ("(", ")") - RENAME_TABLE_WITH_DB = False - NVL2_SUPPORTED = False - SEMI_ANTI_JOIN_WITH_SIDE = False - TABLESAMPLE_KEYWORDS = "USING SAMPLE" - TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" - LAST_DAY_SUPPORTS_DATE_PART = False - JSON_KEY_VALUE_PAIR_SEP = "," - IGNORE_NULLS_IN_FUNC = True - JSON_PATH_BRACKETED_KEY_SUPPORTED = False - SUPPORTS_CREATE_TABLE_LIKE = False - MULTI_ARG_DISTINCT = False - CAN_IMPLEMENT_ARRAY_ANY = True - SUPPORTS_TO_NUMBER = False - SUPPORTS_WINDOW_EXCLUDE = True - COPY_HAS_INTO_KEYWORD = False - STAR_EXCEPT = "EXCLUDE" - PAD_FILL_PATTERN_IS_REQUIRED = True - ARRAY_CONCAT_IS_VAR_LEN = False - ARRAY_SIZE_DIM_REQUIRED = False - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: inline_array_unless_query, - exp.ArrayFilter: rename_func("LIST_FILTER"), - exp.ArrayRemove: remove_from_array_using_filter, - exp.ArraySort: _array_sort_sql, - exp.ArraySum: rename_func("LIST_SUM"), - exp.BitwiseXor: rename_func("XOR"), - exp.CommentColumnConstraint: no_comment_column_constraint_sql, - exp.CurrentDate: lambda *_: "CURRENT_DATE", - exp.CurrentTime: lambda *_: "CURRENT_TIME", - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.DayOfMonth: rename_func("DAYOFMONTH"), - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.DayOfWeekIso: rename_func("ISODOW"), - exp.DayOfYear: rename_func("DAYOFYEAR"), - exp.DataType: _datatype_sql, - exp.Date: _date_sql, - exp.DateAdd: _date_delta_sql, - exp.DateFromParts: rename_func("MAKE_DATE"), - exp.DateSub: _date_delta_sql, - exp.DateDiff: _date_diff_sql, - exp.DateStrToDate: datestrtodate_sql, - exp.Datetime: no_datetime_sql, - exp.DatetimeSub: _date_delta_sql, - exp.DatetimeAdd: _date_delta_sql, - exp.DateToDi: lambda self, - e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.DATEINT_FORMAT}) AS INT)", - exp.Decode: lambda self, e: encode_decode_sql(self, e, "DECODE", replace=False), - exp.DiToDate: lambda self, - e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.DATEINT_FORMAT}) AS DATE)", - exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False), - exp.GenerateDateArray: _generate_datetime_array_sql, - exp.GenerateTimestampArray: _generate_datetime_array_sql, - exp.GroupConcat: lambda self, e: groupconcat_sql(self, e, within_group=False), - exp.HexString: lambda self, e: self.hexstring_sql(e, binary_function_repr="FROM_HEX"), - exp.Explode: rename_func("UNNEST"), - exp.IntDiv: lambda self, e: self.binary(e, "//"), - exp.IsInf: rename_func("ISINF"), - exp.IsNan: 
rename_func("ISNAN"), - exp.JSONBExists: rename_func("JSON_EXISTS"), - exp.JSONExtract: _arrow_json_extract_sql, - exp.JSONExtractArray: _json_extract_value_array_sql, - exp.JSONExtractScalar: _arrow_json_extract_sql, - exp.JSONFormat: _json_format_sql, - exp.JSONValueArray: _json_extract_value_array_sql, - exp.Lateral: explode_to_unnest_sql, - exp.LogicalOr: rename_func("BOOL_OR"), - exp.LogicalAnd: rename_func("BOOL_AND"), - exp.MakeInterval: lambda self, e: no_make_interval_sql(self, e, sep=" "), - exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), - exp.MonthsBetween: lambda self, e: self.func( - "DATEDIFF", - "'month'", - exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True), - exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True), - ), - exp.PercentileCont: rename_func("QUANTILE_CONT"), - exp.PercentileDisc: rename_func("QUANTILE_DISC"), - # DuckDB doesn't allow qualified columns inside of PIVOT expressions. - # See: https://github.com/duckdb/duckdb/blob/671faf92411182f81dce42ac43de8bfb05d9909e/src/planner/binder/tableref/bind_pivot.cpp#L61-L62 - exp.Pivot: transforms.preprocess([transforms.unqualify_columns]), - exp.RegexpReplace: lambda self, e: self.func( - "REGEXP_REPLACE", - e.this, - e.expression, - e.args.get("replacement"), - e.args.get("modifiers"), - ), - exp.RegexpLike: rename_func("REGEXP_MATCHES"), - exp.RegexpILike: lambda self, e: self.func( - "REGEXP_MATCHES", e.this, e.expression, exp.Literal.string("i") - ), - exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), - exp.Return: lambda self, e: self.sql(e, "this"), - exp.ReturnsProperty: lambda self, e: "TABLE" if isinstance(e.this, exp.Schema) else "", - exp.Rand: rename_func("RANDOM"), - exp.SHA: rename_func("SHA1"), - exp.SHA2: sha256_sql, - exp.Split: rename_func("STR_SPLIT"), - exp.SortArray: _sort_array_sql, - exp.StrPosition: strposition_sql, - exp.StrToUnix: lambda self, e: self.func( - "EPOCH", self.func("STRPTIME", e.this, self.format_time(e)) - ), - exp.Struct: _struct_sql, - exp.Transform: rename_func("LIST_TRANSFORM"), - exp.TimeAdd: _date_delta_sql, - exp.Time: no_time_sql, - exp.TimeDiff: _timediff_sql, - exp.Timestamp: no_timestamp_sql, - exp.TimestampDiff: lambda self, e: self.func( - "DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this - ), - exp.TimestampTrunc: timestamptrunc_sql(), - exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)), - exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: lambda self, e: self.func( - "EPOCH", exp.cast(e.this, exp.DataType.Type.TIMESTAMP) - ), - exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)), - exp.TimeToUnix: rename_func("EPOCH"), - exp.TsOrDiToDi: lambda self, - e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", - exp.TsOrDsAdd: _date_delta_sql, - exp.TsOrDsDiff: lambda self, e: self.func( - "DATE_DIFF", - f"'{e.args.get('unit') or 'DAY'}'", - exp.cast(e.expression, exp.DataType.Type.TIMESTAMP), - exp.cast(e.this, exp.DataType.Type.TIMESTAMP), - ), - exp.UnixToStr: lambda self, e: self.func( - "STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e) - ), - exp.DatetimeTrunc: lambda self, e: self.func( - "DATE_TRUNC", unit_to_str(e), exp.cast(e.this, exp.DataType.Type.DATETIME) - ), - exp.UnixToTime: _unix_to_time_sql, - exp.UnixToTimeStr: lambda self, e: f"CAST(TO_TIMESTAMP({self.sql(e, 'this')}) AS TEXT)", - exp.VariancePop: rename_func("VAR_POP"), - exp.WeekOfYear: rename_func("WEEKOFYEAR"), - 
exp.Xor: bool_xor_sql, - exp.Levenshtein: unsupported_args("ins_cost", "del_cost", "sub_cost", "max_dist")( - rename_func("LEVENSHTEIN") - ), - exp.JSONObjectAgg: rename_func("JSON_GROUP_OBJECT"), - exp.JSONBObjectAgg: rename_func("JSON_GROUP_OBJECT"), - exp.DateBin: rename_func("TIME_BUCKET"), - } - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - exp.JSONPathWildcard, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.BINARY: "BLOB", - exp.DataType.Type.BPCHAR: "TEXT", - exp.DataType.Type.CHAR: "TEXT", - exp.DataType.Type.DATETIME: "TIMESTAMP", - exp.DataType.Type.FLOAT: "REAL", - exp.DataType.Type.JSONB: "JSON", - exp.DataType.Type.NCHAR: "TEXT", - exp.DataType.Type.NVARCHAR: "TEXT", - exp.DataType.Type.UINT: "UINTEGER", - exp.DataType.Type.VARBINARY: "BLOB", - exp.DataType.Type.ROWVERSION: "BLOB", - exp.DataType.Type.VARCHAR: "TEXT", - exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP", - exp.DataType.Type.TIMESTAMP_S: "TIMESTAMP_S", - exp.DataType.Type.TIMESTAMP_MS: "TIMESTAMP_MS", - exp.DataType.Type.TIMESTAMP_NS: "TIMESTAMP_NS", - } - - # https://github.com/duckdb/duckdb/blob/ff7f24fd8e3128d94371827523dae85ebaf58713/third_party/libpg_query/grammar/keywords/reserved_keywords.list#L1-L77 - RESERVED_KEYWORDS = { - "array", - "analyse", - "union", - "all", - "when", - "in_p", - "default", - "create_p", - "window", - "asymmetric", - "to", - "else", - "localtime", - "from", - "end_p", - "select", - "current_date", - "foreign", - "with", - "grant", - "session_user", - "or", - "except", - "references", - "fetch", - "limit", - "group_p", - "leading", - "into", - "collate", - "offset", - "do", - "then", - "localtimestamp", - "check_p", - "lateral_p", - "current_role", - "where", - "asc_p", - "placing", - "desc_p", - "user", - "unique", - "initially", - "column", - "both", - "some", - "as", - "any", - "only", - "deferrable", - "null_p", - "current_time", - "true_p", - "table", - "case", - "trailing", - "variadic", - "for", - "on", - "distinct", - "false_p", - "not", - "constraint", - "current_timestamp", - "returning", - "primary", - "intersect", - "having", - "analyze", - "current_user", - "and", - "cast", - "symmetric", - "using", - "order", - "current_catalog", - } - - UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren) - - # DuckDB doesn't generally support CREATE TABLE .. properties - # https://duckdb.org/docs/sql/statements/create_table.html - PROPERTIES_LOCATION = { - prop: exp.Properties.Location.UNSUPPORTED - for prop in generator.Generator.PROPERTIES_LOCATION - } - - # There are a few exceptions (e.g. 
temporary tables) which are supported or - # can be transpiled to DuckDB, so we explicitly override them accordingly - PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA - PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE - PROPERTIES_LOCATION[exp.ReturnsProperty] = exp.Properties.Location.POST_ALIAS - - IGNORE_RESPECT_NULLS_WINDOW_FUNCTIONS = ( - exp.FirstValue, - exp.Lag, - exp.LastValue, - exp.Lead, - exp.NthValue, - ) - - def show_sql(self, expression: exp.Show) -> str: - return f"SHOW {expression.name}" - - def fromiso8601timestamp_sql(self, expression: exp.FromISO8601Timestamp) -> str: - return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)) - - def strtotime_sql(self, expression: exp.StrToTime) -> str: - if expression.args.get("safe"): - formatted_time = self.format_time(expression) - return f"CAST({self.func('TRY_STRPTIME', expression.this, formatted_time)} AS TIMESTAMP)" - return str_to_time_sql(self, expression) - - def strtodate_sql(self, expression: exp.StrToDate) -> str: - if expression.args.get("safe"): - formatted_time = self.format_time(expression) - return f"CAST({self.func('TRY_STRPTIME', expression.this, formatted_time)} AS DATE)" - return f"CAST({str_to_time_sql(self, expression)} AS DATE)" - - def parsejson_sql(self, expression: exp.ParseJSON) -> str: - arg = expression.this - if expression.args.get("safe"): - return self.sql(exp.case().when(exp.func("json_valid", arg), arg).else_(exp.null())) - return self.func("JSON", arg) - - def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: - nano = expression.args.get("nano") - if nano is not None: - expression.set( - "sec", expression.args["sec"] + nano.pop() / exp.Literal.number(1000000000.0) - ) - - return rename_func("MAKE_TIME")(self, expression) - - def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: - sec = expression.args["sec"] - - milli = expression.args.get("milli") - if milli is not None: - sec += milli.pop() / exp.Literal.number(1000.0) - - nano = expression.args.get("nano") - if nano is not None: - sec += nano.pop() / exp.Literal.number(1000000000.0) - - if milli or nano: - expression.set("sec", sec) - - return rename_func("MAKE_TIMESTAMP")(self, expression) - - def tablesample_sql( - self, - expression: exp.TableSample, - tablesample_keyword: t.Optional[str] = None, - ) -> str: - if not isinstance(expression.parent, exp.Select): - # This sample clause only applies to a single source, not the entire resulting relation - tablesample_keyword = "TABLESAMPLE" - - if expression.args.get("size"): - method = expression.args.get("method") - if method and method.name.upper() != "RESERVOIR": - self.unsupported( - f"Sampling method {method} is not supported with a discrete sample count, " - "defaulting to reservoir sampling" - ) - expression.set("method", exp.var("RESERVOIR")) - - return super().tablesample_sql(expression, tablesample_keyword=tablesample_keyword) - - def interval_sql(self, expression: exp.Interval) -> str: - multiplier: t.Optional[int] = None - unit = expression.text("unit").lower() - - if unit.startswith("week"): - multiplier = 7 - if unit.startswith("quarter"): - multiplier = 90 - - if multiplier: - return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})" - - return super().interval_sql(expression) - - def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: - if isinstance(expression.parent, exp.UserDefinedFunction): - return 
self.sql(expression, "this") - return super().columndef_sql(expression, sep) - - def join_sql(self, expression: exp.Join) -> str: - if ( - expression.side == "LEFT" - and not expression.args.get("on") - and isinstance(expression.this, exp.Unnest) - ): - # Some dialects support `LEFT JOIN UNNEST(...)` without an explicit ON clause - # DuckDB doesn't, but we can just add a dummy ON clause that is always true - return super().join_sql(expression.on(exp.true())) - - return super().join_sql(expression) - - def generateseries_sql(self, expression: exp.GenerateSeries) -> str: - # GENERATE_SERIES(a, b) -> [a, b], RANGE(a, b) -> [a, b) - if expression.args.get("is_end_exclusive"): - return rename_func("RANGE")(self, expression) - - return self.function_fallback_sql(expression) - - def countif_sql(self, expression: exp.CountIf) -> str: - if self.dialect.version >= Version("1.2"): - return self.function_fallback_sql(expression) - - # https://github.com/tobymao/sqlglot/pull/4749 - return count_if_to_sum(self, expression) - - def bracket_sql(self, expression: exp.Bracket) -> str: - if self.dialect.version >= Version("1.2"): - return super().bracket_sql(expression) - - # https://duckdb.org/2025/02/05/announcing-duckdb-120.html#breaking-changes - this = expression.this - if isinstance(this, exp.Array): - this.replace(exp.paren(this)) - - bracket = super().bracket_sql(expression) - - if not expression.args.get("returns_list_for_maps"): - if not this.type: - from sqlglot.optimizer.annotate_types import annotate_types - - this = annotate_types(this, dialect=self.dialect) - - if this.is_type(exp.DataType.Type.MAP): - bracket = f"({bracket})[1]" - - return bracket - - def withingroup_sql(self, expression: exp.WithinGroup) -> str: - expression_sql = self.sql(expression, "expression") - - func = expression.this - if isinstance(func, exp.PERCENTILES): - # Make the order key the first arg and slide the fraction to the right - # https://duckdb.org/docs/sql/aggregates#ordered-set-aggregate-functions - order_col = expression.find(exp.Ordered) - if order_col: - func.set("expression", func.this) - func.set("this", order_col.this) - - this = self.sql(expression, "this").rstrip(")") - - return f"{this}{expression_sql})" - - def length_sql(self, expression: exp.Length) -> str: - arg = expression.this - - # Dialects like BQ and Snowflake also accept binary values as args, so - # DDB will attempt to infer the type or resort to case/when resolution - if not expression.args.get("binary") or arg.is_string: - return self.func("LENGTH", arg) - - if not arg.type: - from sqlglot.optimizer.annotate_types import annotate_types - - arg = annotate_types(arg, dialect=self.dialect) - - if arg.is_type(*exp.DataType.TEXT_TYPES): - return self.func("LENGTH", arg) - - # We need these casts to make duckdb's static type checker happy - blob = exp.cast(arg, exp.DataType.Type.VARBINARY) - varchar = exp.cast(arg, exp.DataType.Type.VARCHAR) - - case = ( - exp.case(self.func("TYPEOF", arg)) - .when("'BLOB'", self.func("OCTET_LENGTH", blob)) - .else_( - exp.Anonymous(this="LENGTH", expressions=[varchar]) - ) # anonymous to break length_sql recursion - ) - - return self.sql(case) - - def objectinsert_sql(self, expression: exp.ObjectInsert) -> str: - this = expression.this - key = expression.args.get("key") - key_sql = key.name if isinstance(key, exp.Expression) else "" - value_sql = self.sql(expression, "value") - - kv_sql = f"{key_sql} := {value_sql}" - - # If the input struct is empty e.g. 
transpiling OBJECT_INSERT(OBJECT_CONSTRUCT(), key, value) from Snowflake - # then we can generate STRUCT_PACK which will build it since STRUCT_INSERT({}, key := value) is not valid DuckDB - if isinstance(this, exp.Struct) and not this.expressions: - return self.func("STRUCT_PACK", kv_sql) - - return self.func("STRUCT_INSERT", this, kv_sql) - - def unnest_sql(self, expression: exp.Unnest) -> str: - explode_array = expression.args.get("explode_array") - if explode_array: - # In BigQuery, UNNESTing a nested array leads to explosion of the top-level array & struct - # This is transpiled to DDB by transforming "FROM UNNEST(...)" to "FROM (SELECT UNNEST(..., max_depth => 2))" - expression.expressions.append( - exp.Kwarg(this=exp.var("max_depth"), expression=exp.Literal.number(2)) - ) - - # If BQ's UNNEST is aliased, we transform it from a column alias to a table alias in DDB - alias = expression.args.get("alias") - if alias: - expression.set("alias", None) - alias = exp.TableAlias(this=seq_get(alias.args.get("columns"), 0)) - - unnest_sql = super().unnest_sql(expression) - select = exp.Select(expressions=[unnest_sql]).subquery(alias) - return self.sql(select) - - return super().unnest_sql(expression) - - def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str: - if isinstance(expression.this, self.IGNORE_RESPECT_NULLS_WINDOW_FUNCTIONS): - # DuckDB should render IGNORE NULLS only for the general-purpose - # window functions that accept it e.g. FIRST_VALUE(... IGNORE NULLS) OVER (...) - return super().ignorenulls_sql(expression) - - self.unsupported("IGNORE NULLS is not supported for non-window functions.") - return self.sql(expression, "this") - - def respectnulls_sql(self, expression: exp.RespectNulls) -> str: - if isinstance(expression.this, self.IGNORE_RESPECT_NULLS_WINDOW_FUNCTIONS): - # DuckDB should render RESPECT NULLS only for the general-purpose - # window functions that accept it e.g. FIRST_VALUE(... RESPECT NULLS) OVER (...) 
- return super().respectnulls_sql(expression) - - self.unsupported("RESPECT NULLS is not supported for non-window functions.") - return self.sql(expression, "this") - - def arraytostring_sql(self, expression: exp.ArrayToString) -> str: - this = self.sql(expression, "this") - null_text = self.sql(expression, "null") - - if null_text: - this = f"LIST_TRANSFORM({this}, x -> COALESCE(x, {null_text}))" - - return self.func("ARRAY_TO_STRING", this, expression.expression) - - @unsupported_args("position", "occurrence") - def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: - group = expression.args.get("group") - params = expression.args.get("parameters") - - # Do not render group if there is no following argument, - # and it's the default value for this dialect - if ( - not params - and group - and group.name == str(self.dialect.REGEXP_EXTRACT_DEFAULT_GROUP) - ): - group = None - return self.func( - "REGEXP_EXTRACT", expression.this, expression.expression, group, params - ) - - @unsupported_args("culture") - def numbertostr_sql(self, expression: exp.NumberToStr) -> str: - fmt = expression.args.get("format") - if fmt and fmt.is_int: - return self.func("FORMAT", f"'{{:,.{fmt.name}f}}'", expression.this) - - self.unsupported("Only integer formats are supported by NumberToStr") - return self.function_fallback_sql(expression) - - def autoincrementcolumnconstraint_sql(self, _) -> str: - self.unsupported("The AUTOINCREMENT column constraint is not supported by DuckDB") - return "" diff --git a/altimate_packages/sqlglot/dialects/dune.py b/altimate_packages/sqlglot/dialects/dune.py deleted file mode 100644 index be870e2e0..000000000 --- a/altimate_packages/sqlglot/dialects/dune.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - - -from sqlglot import exp -from sqlglot.dialects.trino import Trino - - -class Dune(Trino): - class Tokenizer(Trino.Tokenizer): - HEX_STRINGS = ["0x", ("X'", "'")] - - class Generator(Trino.Generator): - TRANSFORMS = { - **Trino.Generator.TRANSFORMS, - exp.HexString: lambda self, e: f"0x{e.this}", - } diff --git a/altimate_packages/sqlglot/dialects/hive.py b/altimate_packages/sqlglot/dialects/hive.py deleted file mode 100644 index 2bd1a95e2..000000000 --- a/altimate_packages/sqlglot/dialects/hive.py +++ /dev/null @@ -1,787 +0,0 @@ -from __future__ import annotations - -import typing as t -from copy import deepcopy -from functools import partial -from collections import defaultdict - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - DATE_ADD_OR_SUB, - Dialect, - NormalizationStrategy, - approx_count_distinct_sql, - arg_max_or_min_no_count, - datestrtodate_sql, - build_formatted_time, - if_sql, - is_parse_json, - left_to_substring_sql, - max_or_greatest, - min_or_least, - no_ilike_sql, - no_recursive_cte_sql, - no_trycast_sql, - regexp_extract_sql, - regexp_replace_sql, - rename_func, - right_to_substring_sql, - strposition_sql, - struct_extract_sql, - time_format, - timestrtotime_sql, - unit_to_str, - var_map_sql, - sequence_sql, - property_sql, - build_regexp_extract, -) -from sqlglot.transforms import ( - remove_unique_constraints, - ctas_with_tmp_tables_to_create_tmp_view, - preprocess, - move_schema_columns_to_partitioned_by, -) -from sqlglot.helper import seq_get -from sqlglot.tokens import TokenType -from sqlglot.generator import unsupported_args -from sqlglot.optimizer.annotate_types import TypeAnnotator - -# (FuncType, Multiplier) -DATE_DELTA_INTERVAL = { - "YEAR": ("ADD_MONTHS", 
12), - "MONTH": ("ADD_MONTHS", 1), - "QUARTER": ("ADD_MONTHS", 3), - "WEEK": ("DATE_ADD", 7), - "DAY": ("DATE_ADD", 1), -} - -TIME_DIFF_FACTOR = { - "MILLISECOND": " * 1000", - "SECOND": "", - "MINUTE": " / 60", - "HOUR": " / 3600", -} - -DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") - -TS_OR_DS_EXPRESSIONS = ( - exp.DateDiff, - exp.Day, - exp.Month, - exp.Year, -) - - -def _add_date_sql(self: Hive.Generator, expression: DATE_ADD_OR_SUB) -> str: - if isinstance(expression, exp.TsOrDsAdd) and not expression.unit: - return self.func("DATE_ADD", expression.this, expression.expression) - - unit = expression.text("unit").upper() - func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) - - if isinstance(expression, exp.DateSub): - multiplier *= -1 - - increment = expression.expression - if isinstance(increment, exp.Literal): - value = increment.to_py() if increment.is_number else int(increment.name) - increment = exp.Literal.number(value * multiplier) - elif multiplier != 1: - increment *= exp.Literal.number(multiplier) - - return self.func(func, expression.this, increment) - - -def _date_diff_sql(self: Hive.Generator, expression: exp.DateDiff | exp.TsOrDsDiff) -> str: - unit = expression.text("unit").upper() - - factor = TIME_DIFF_FACTOR.get(unit) - if factor is not None: - left = self.sql(expression, "this") - right = self.sql(expression, "expression") - sec_diff = f"UNIX_TIMESTAMP({left}) - UNIX_TIMESTAMP({right})" - return f"({sec_diff}){factor}" if factor else sec_diff - - months_between = unit in DIFF_MONTH_SWITCH - sql_func = "MONTHS_BETWEEN" if months_between else "DATEDIFF" - _, multiplier = DATE_DELTA_INTERVAL.get(unit, ("", 1)) - multiplier_sql = f" / {multiplier}" if multiplier > 1 else "" - diff_sql = f"{sql_func}({self.format_args(expression.this, expression.expression)})" - - if months_between or multiplier_sql: - # MONTHS_BETWEEN returns a float, so we need to truncate the fractional part. - # For the same reason, we want to truncate if there's a divisor present. 
- diff_sql = f"CAST({diff_sql}{multiplier_sql} AS INT)" - - return diff_sql - - -def _json_format_sql(self: Hive.Generator, expression: exp.JSONFormat) -> str: - this = expression.this - - if is_parse_json(this): - if this.this.is_string: - # Since FROM_JSON requires a nested type, we always wrap the json string with - # an array to ensure that "naked" strings like "'a'" will be handled correctly - wrapped_json = exp.Literal.string(f"[{this.this.name}]") - - from_json = self.func( - "FROM_JSON", wrapped_json, self.func("SCHEMA_OF_JSON", wrapped_json) - ) - to_json = self.func("TO_JSON", from_json) - - # This strips the [, ] delimiters of the dummy array printed by TO_JSON - return self.func("REGEXP_EXTRACT", to_json, "'^.(.*).$'", "1") - return self.sql(this) - - return self.func("TO_JSON", this, expression.args.get("options")) - - -@generator.unsupported_args(("expression", "Hive's SORT_ARRAY does not support a comparator.")) -def _array_sort_sql(self: Hive.Generator, expression: exp.ArraySort) -> str: - return self.func("SORT_ARRAY", expression.this) - - -def _str_to_unix_sql(self: Hive.Generator, expression: exp.StrToUnix) -> str: - return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression)) - - -def _unix_to_time_sql(self: Hive.Generator, expression: exp.UnixToTime) -> str: - timestamp = self.sql(expression, "this") - scale = expression.args.get("scale") - if scale in (None, exp.UnixToTime.SECONDS): - return rename_func("FROM_UNIXTIME")(self, expression) - - return f"FROM_UNIXTIME({timestamp} / POW(10, {scale}))" - - -def _str_to_date_sql(self: Hive.Generator, expression: exp.StrToDate) -> str: - this = self.sql(expression, "this") - time_format = self.format_time(expression) - if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): - this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" - return f"CAST({this} AS DATE)" - - -def _str_to_time_sql(self: Hive.Generator, expression: exp.StrToTime) -> str: - this = self.sql(expression, "this") - time_format = self.format_time(expression) - if time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): - this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" - return f"CAST({this} AS TIMESTAMP)" - - -def _to_date_sql(self: Hive.Generator, expression: exp.TsOrDsToDate) -> str: - time_format = self.format_time(expression) - if time_format and time_format not in (Hive.TIME_FORMAT, Hive.DATE_FORMAT): - return self.func("TO_DATE", expression.this, time_format) - - if isinstance(expression.parent, TS_OR_DS_EXPRESSIONS): - return self.sql(expression, "this") - - return self.func("TO_DATE", expression.this) - - -def _build_with_ignore_nulls( - exp_class: t.Type[exp.Expression], -) -> t.Callable[[t.List[exp.Expression]], exp.Expression]: - def _parse(args: t.List[exp.Expression]) -> exp.Expression: - this = exp_class(this=seq_get(args, 0)) - if seq_get(args, 1) == exp.true(): - return exp.IgnoreNulls(this=this) - return this - - return _parse - - -def _build_to_date(args: t.List) -> exp.TsOrDsToDate: - expr = build_formatted_time(exp.TsOrDsToDate, "hive")(args) - expr.set("safe", True) - return expr - - -class Hive(Dialect): - ALIAS_POST_TABLESAMPLE = True - IDENTIFIERS_CAN_START_WITH_DIGIT = True - SUPPORTS_USER_DEFINED_TYPES = False - SAFE_DIVISION = True - ARRAY_AGG_INCLUDES_NULLS = None - REGEXP_EXTRACT_DEFAULT_GROUP = 1 - - # https://spark.apache.org/docs/latest/sql-ref-identifier.html#description - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE - - ANNOTATORS = { - 
**Dialect.ANNOTATORS, - exp.If: lambda self, e: self._annotate_by_args(e, "true", "false", promote=True), - exp.Coalesce: lambda self, e: self._annotate_by_args( - e, "this", "expressions", promote=True - ), - } - - # Support only the non-ANSI mode (default for Hive, Spark2, Spark) - COERCES_TO = defaultdict(set, deepcopy(TypeAnnotator.COERCES_TO)) - for target_type in { - *exp.DataType.NUMERIC_TYPES, - *exp.DataType.TEMPORAL_TYPES, - exp.DataType.Type.INTERVAL, - }: - COERCES_TO[target_type] |= exp.DataType.TEXT_TYPES - - TIME_MAPPING = { - "y": "%Y", - "Y": "%Y", - "YYYY": "%Y", - "yyyy": "%Y", - "YY": "%y", - "yy": "%y", - "MMMM": "%B", - "MMM": "%b", - "MM": "%m", - "M": "%-m", - "dd": "%d", - "d": "%-d", - "HH": "%H", - "H": "%-H", - "hh": "%I", - "h": "%-I", - "mm": "%M", - "m": "%-M", - "ss": "%S", - "s": "%-S", - "SSSSSS": "%f", - "a": "%p", - "DD": "%j", - "D": "%-j", - "E": "%a", - "EE": "%a", - "EEE": "%a", - "EEEE": "%A", - "z": "%Z", - "Z": "%z", - } - - DATE_FORMAT = "'yyyy-MM-dd'" - DATEINT_FORMAT = "'yyyyMMdd'" - TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'" - - class Tokenizer(tokens.Tokenizer): - QUOTES = ["'", '"'] - IDENTIFIERS = ["`"] - STRING_ESCAPES = ["\\"] - - SINGLE_TOKENS = { - **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, - } - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "ADD ARCHIVE": TokenType.COMMAND, - "ADD ARCHIVES": TokenType.COMMAND, - "ADD FILE": TokenType.COMMAND, - "ADD FILES": TokenType.COMMAND, - "ADD JAR": TokenType.COMMAND, - "ADD JARS": TokenType.COMMAND, - "MINUS": TokenType.EXCEPT, - "MSCK REPAIR": TokenType.COMMAND, - "REFRESH": TokenType.REFRESH, - "TIMESTAMP AS OF": TokenType.TIMESTAMP_SNAPSHOT, - "VERSION AS OF": TokenType.VERSION_SNAPSHOT, - "SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, - } - - NUMERIC_LITERALS = { - "L": "BIGINT", - "S": "SMALLINT", - "Y": "TINYINT", - "D": "DOUBLE", - "F": "FLOAT", - "BD": "DECIMAL", - } - - class Parser(parser.Parser): - LOG_DEFAULTS_TO_LN = True - STRICT_CAST = False - VALUES_FOLLOWED_BY_PAREN = False - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "ASCII": exp.Unicode.from_arg_list, - "BASE64": exp.ToBase64.from_arg_list, - "COLLECT_LIST": lambda args: exp.ArrayAgg(this=seq_get(args, 0), nulls_excluded=True), - "COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list, - "DATE_ADD": lambda args: exp.TsOrDsAdd( - this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") - ), - "DATE_FORMAT": lambda args: build_formatted_time(exp.TimeToStr, "hive")( - [ - exp.TimeStrToTime(this=seq_get(args, 0)), - seq_get(args, 1), - ] - ), - "DATE_SUB": lambda args: exp.TsOrDsAdd( - this=seq_get(args, 0), - expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)), - unit=exp.Literal.string("DAY"), - ), - "DATEDIFF": lambda args: exp.DateDiff( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - expression=exp.TsOrDsToDate(this=seq_get(args, 1)), - ), - "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "FIRST": _build_with_ignore_nulls(exp.First), - "FIRST_VALUE": _build_with_ignore_nulls(exp.FirstValue), - "FROM_UNIXTIME": build_formatted_time(exp.UnixToStr, "hive", True), - "GET_JSON_OBJECT": lambda args, dialect: exp.JSONExtractScalar( - this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) - ), - "LAST": _build_with_ignore_nulls(exp.Last), - "LAST_VALUE": _build_with_ignore_nulls(exp.LastValue), - "MAP": parser.build_var_map, - "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), - "PERCENTILE": 
exp.Quantile.from_arg_list, - "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, - "REGEXP_EXTRACT": build_regexp_extract(exp.RegexpExtract), - "REGEXP_EXTRACT_ALL": build_regexp_extract(exp.RegexpExtractAll), - "SEQUENCE": exp.GenerateSeries.from_arg_list, - "SIZE": exp.ArraySize.from_arg_list, - "SPLIT": exp.RegexpSplit.from_arg_list, - "STR_TO_MAP": lambda args: exp.StrToMap( - this=seq_get(args, 0), - pair_delim=seq_get(args, 1) or exp.Literal.string(","), - key_value_delim=seq_get(args, 2) or exp.Literal.string(":"), - ), - "TO_DATE": _build_to_date, - "TO_JSON": exp.JSONFormat.from_arg_list, - "TRUNC": exp.TimestampTrunc.from_arg_list, - "UNBASE64": exp.FromBase64.from_arg_list, - "UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)( - args or [exp.CurrentTimestamp()] - ), - "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), - } - - NO_PAREN_FUNCTION_PARSERS = { - **parser.Parser.NO_PAREN_FUNCTION_PARSERS, - "TRANSFORM": lambda self: self._parse_transform(), - } - - NO_PAREN_FUNCTIONS = parser.Parser.NO_PAREN_FUNCTIONS.copy() - NO_PAREN_FUNCTIONS.pop(TokenType.CURRENT_TIME) - - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, - "SERDEPROPERTIES": lambda self: exp.SerdeProperties( - expressions=self._parse_wrapped_csv(self._parse_property) - ), - } - - def _parse_transform(self) -> t.Optional[exp.Transform | exp.QueryTransform]: - if not self._match(TokenType.L_PAREN, advance=False): - self._retreat(self._index - 1) - return None - - args = self._parse_wrapped_csv(self._parse_lambda) - row_format_before = self._parse_row_format(match_row=True) - - record_writer = None - if self._match_text_seq("RECORDWRITER"): - record_writer = self._parse_string() - - if not self._match(TokenType.USING): - return exp.Transform.from_arg_list(args) - - command_script = self._parse_string() - - self._match(TokenType.ALIAS) - schema = self._parse_schema() - - row_format_after = self._parse_row_format(match_row=True) - record_reader = None - if self._match_text_seq("RECORDREADER"): - record_reader = self._parse_string() - - return self.expression( - exp.QueryTransform, - expressions=args, - command_script=command_script, - schema=schema, - row_format_before=row_format_before, - record_writer=record_writer, - row_format_after=row_format_after, - record_reader=record_reader, - ) - - def _parse_types( - self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True - ) -> t.Optional[exp.Expression]: - """ - Spark (and most likely Hive) treats casts to CHAR(length) and VARCHAR(length) as casts to - STRING in all contexts except for schema definitions. For example, this is in Spark v3.4.0: - - spark-sql (default)> select cast(1234 as varchar(2)); - 23/06/06 15:51:18 WARN CharVarcharUtils: The Spark cast operator does not support - char/varchar type and simply treats them as string type. Please use string type - directly to avoid confusion. Otherwise, you can set spark.sql.legacy.charVarcharAsString - to true, so that Spark treat them as string type as same as Spark 3.0 and earlier - - 1234 - Time taken: 4.265 seconds, Fetched 1 row(s) - - This shows that Spark doesn't truncate the value into '12', which is inconsistent with - what other dialects (e.g. postgres) do, so we need to drop the length to transpile correctly. 
- - Reference: https://spark.apache.org/docs/latest/sql-ref-datatypes.html - """ - this = super()._parse_types( - check_func=check_func, schema=schema, allow_identifiers=allow_identifiers - ) - - if this and not schema: - return this.transform( - lambda node: ( - node.replace(exp.DataType.build("text")) - if isinstance(node, exp.DataType) and node.is_type("char", "varchar") - else node - ), - copy=False, - ) - - return this - - def _parse_partition_and_order( - self, - ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]: - return ( - ( - self._parse_csv(self._parse_assignment) - if self._match_set({TokenType.PARTITION_BY, TokenType.DISTRIBUTE_BY}) - else [] - ), - super()._parse_order(skip_order_token=self._match(TokenType.SORT_BY)), - ) - - def _parse_parameter(self) -> exp.Parameter: - self._match(TokenType.L_BRACE) - this = self._parse_identifier() or self._parse_primary_or_var() - expression = self._match(TokenType.COLON) and ( - self._parse_identifier() or self._parse_primary_or_var() - ) - self._match(TokenType.R_BRACE) - return self.expression(exp.Parameter, this=this, expression=expression) - - def _to_prop_eq(self, expression: exp.Expression, index: int) -> exp.Expression: - if expression.is_star: - return expression - - if isinstance(expression, exp.Column): - key = expression.this - else: - key = exp.to_identifier(f"col{index + 1}") - - return self.expression(exp.PropertyEQ, this=key, expression=expression) - - class Generator(generator.Generator): - LIMIT_FETCH = "LIMIT" - TABLESAMPLE_WITH_METHOD = False - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - INDEX_ON = "ON TABLE" - EXTRACT_ALLOWS_QUOTES = False - NVL2_SUPPORTED = False - LAST_DAY_SUPPORTS_DATE_PART = False - JSON_PATH_SINGLE_QUOTE_ESCAPE = True - SUPPORTS_TO_NUMBER = False - WITH_PROPERTIES_PREFIX = "TBLPROPERTIES" - PARSE_JSON_NAME: t.Optional[str] = None - PAD_FILL_PATTERN_IS_REQUIRED = True - SUPPORTS_MEDIAN = False - ARRAY_SIZE_NAME = "SIZE" - - EXPRESSIONS_WITHOUT_NESTED_CTES = { - exp.Insert, - exp.Select, - exp.Subquery, - exp.SetOperation, - } - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - exp.JSONPathWildcard, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.BIT: "BOOLEAN", - exp.DataType.Type.BLOB: "BINARY", - exp.DataType.Type.DATETIME: "TIMESTAMP", - exp.DataType.Type.ROWVERSION: "BINARY", - exp.DataType.Type.TEXT: "STRING", - exp.DataType.Type.TIME: "TIMESTAMP", - exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP", - exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.UTINYINT: "SMALLINT", - exp.DataType.Type.VARBINARY: "BINARY", - } - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.Group: transforms.preprocess([transforms.unalias_group]), - exp.Property: property_sql, - exp.AnyValue: rename_func("FIRST"), - exp.ApproxDistinct: approx_count_distinct_sql, - exp.ArgMax: arg_max_or_min_no_count("MAX_BY"), - exp.ArgMin: arg_max_or_min_no_count("MIN_BY"), - exp.ArrayConcat: rename_func("CONCAT"), - exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this), - exp.ArraySort: _array_sort_sql, - exp.With: no_recursive_cte_sql, - exp.DateAdd: _add_date_sql, - exp.DateDiff: _date_diff_sql, - exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _add_date_sql, - exp.DateToDi: lambda self, - e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.DATEINT_FORMAT}) AS INT)", - exp.DiToDate: lambda self, - e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), 
{Hive.DATEINT_FORMAT})", - exp.FileFormatProperty: lambda self, - e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", - exp.StorageHandlerProperty: lambda self, e: f"STORED BY {self.sql(e, 'this')}", - exp.FromBase64: rename_func("UNBASE64"), - exp.GenerateSeries: sequence_sql, - exp.GenerateDateArray: sequence_sql, - exp.If: if_sql(), - exp.ILike: no_ilike_sql, - exp.IntDiv: lambda self, e: self.binary(e, "DIV"), - exp.IsNan: rename_func("ISNAN"), - exp.JSONExtract: lambda self, e: self.func("GET_JSON_OBJECT", e.this, e.expression), - exp.JSONExtractScalar: lambda self, e: self.func( - "GET_JSON_OBJECT", e.this, e.expression - ), - exp.JSONFormat: _json_format_sql, - exp.Left: left_to_substring_sql, - exp.Map: var_map_sql, - exp.Max: max_or_greatest, - exp.MD5Digest: lambda self, e: self.func("UNHEX", self.func("MD5", e.this)), - exp.Min: min_or_least, - exp.MonthsBetween: lambda self, e: self.func("MONTHS_BETWEEN", e.this, e.expression), - exp.NotNullColumnConstraint: lambda _, e: ( - "" if e.args.get("allow_null") else "NOT NULL" - ), - exp.VarMap: var_map_sql, - exp.Create: preprocess( - [ - remove_unique_constraints, - ctas_with_tmp_tables_to_create_tmp_view, - move_schema_columns_to_partitioned_by, - ] - ), - exp.Quantile: rename_func("PERCENTILE"), - exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), - exp.RegexpExtract: regexp_extract_sql, - exp.RegexpExtractAll: regexp_extract_sql, - exp.RegexpReplace: regexp_replace_sql, - exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), - exp.RegexpSplit: rename_func("SPLIT"), - exp.Right: right_to_substring_sql, - exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), - exp.ArrayUniqueAgg: rename_func("COLLECT_SET"), - exp.Split: lambda self, e: self.func( - "SPLIT", e.this, self.func("CONCAT", "'\\\\Q'", e.expression, "'\\\\E'") - ), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_qualify, - transforms.eliminate_distinct_on, - partial(transforms.unnest_to_explode, unnest_using_arrays_zip=False), - transforms.any_to_exists, - ] - ), - exp.StrPosition: lambda self, e: strposition_sql( - self, e, func_name="LOCATE", supports_position=True - ), - exp.StrToDate: _str_to_date_sql, - exp.StrToTime: _str_to_time_sql, - exp.StrToUnix: _str_to_unix_sql, - exp.StructExtract: struct_extract_sql, - exp.StarMap: rename_func("MAP"), - exp.Table: transforms.preprocess([transforms.unnest_generate_series]), - exp.TimeStrToDate: rename_func("TO_DATE"), - exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), - exp.TimestampTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)), - exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), - exp.ToBase64: rename_func("BASE64"), - exp.TsOrDiToDi: lambda self, - e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", - exp.TsOrDsAdd: _add_date_sql, - exp.TsOrDsDiff: _date_diff_sql, - exp.TsOrDsToDate: _to_date_sql, - exp.TryCast: no_trycast_sql, - exp.Unicode: rename_func("ASCII"), - exp.UnixToStr: lambda self, e: self.func( - "FROM_UNIXTIME", e.this, time_format("hive")(self, e) - ), - exp.UnixToTime: _unix_to_time_sql, - exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), - exp.Unnest: rename_func("EXPLODE"), - exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}", - exp.NumberToStr: rename_func("FORMAT_NUMBER"), - exp.National: lambda self, e: self.national_sql(e, prefix=""), - exp.ClusteredColumnConstraint: lambda self, - e: 
f"({self.expressions(e, 'this', indent=False)})", - exp.NonClusteredColumnConstraint: lambda self, - e: f"({self.expressions(e, 'this', indent=False)})", - exp.NotForReplicationColumnConstraint: lambda *_: "", - exp.OnProperty: lambda *_: "", - exp.PartitionedByBucket: lambda self, e: self.func("BUCKET", e.expression, e.this), - exp.PartitionByTruncate: lambda self, e: self.func("TRUNCATE", e.expression, e.this), - exp.PrimaryKeyColumnConstraint: lambda *_: "PRIMARY KEY", - exp.WeekOfYear: rename_func("WEEKOFYEAR"), - exp.DayOfMonth: rename_func("DAYOFMONTH"), - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.Levenshtein: unsupported_args("ins_cost", "del_cost", "sub_cost", "max_dist")( - rename_func("LEVENSHTEIN") - ), - } - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - exp.WithDataProperty: exp.Properties.Location.UNSUPPORTED, - } - - def unnest_sql(self, expression: exp.Unnest) -> str: - return rename_func("EXPLODE")(self, expression) - - def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str: - if isinstance(expression.this, exp.JSONPathWildcard): - self.unsupported("Unsupported wildcard in JSONPathKey expression") - return "" - - return super()._jsonpathkey_sql(expression) - - def parameter_sql(self, expression: exp.Parameter) -> str: - this = self.sql(expression, "this") - expression_sql = self.sql(expression, "expression") - - parent = expression.parent - this = f"{this}:{expression_sql}" if expression_sql else this - - if isinstance(parent, exp.EQ) and isinstance(parent.parent, exp.SetItem): - # We need to produce SET key = value instead of SET ${key} = value - return this - - return f"${{{this}}}" - - def schema_sql(self, expression: exp.Schema) -> str: - for ordered in expression.find_all(exp.Ordered): - if ordered.args.get("desc") is False: - ordered.set("desc", None) - - return super().schema_sql(expression) - - def constraint_sql(self, expression: exp.Constraint) -> str: - for prop in list(expression.find_all(exp.Properties)): - prop.pop() - - this = self.sql(expression, "this") - expressions = self.expressions(expression, sep=" ", flat=True) - return f"CONSTRAINT {this} {expressions}" - - def rowformatserdeproperty_sql(self, expression: exp.RowFormatSerdeProperty) -> str: - serde_props = self.sql(expression, "serde_properties") - serde_props = f" {serde_props}" if serde_props else "" - return f"ROW FORMAT SERDE {self.sql(expression, 'this')}{serde_props}" - - def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: - return self.func( - "COLLECT_LIST", - expression.this.this if isinstance(expression.this, exp.Order) else expression.this, - ) - - def datatype_sql(self, expression: exp.DataType) -> str: - if expression.this in self.PARAMETERIZABLE_TEXT_TYPES and ( - not expression.expressions or expression.expressions[0].name == "MAX" - ): - expression = exp.DataType.build("text") - elif expression.is_type(exp.DataType.Type.TEXT) and expression.expressions: - expression.set("this", exp.DataType.Type.VARCHAR) - elif expression.this in exp.DataType.TEMPORAL_TYPES: - expression = exp.DataType.build(expression.this) - elif expression.is_type("float"): - size_expression = expression.find(exp.DataTypeParam) - if size_expression: - size = int(size_expression.name) - expression = ( - exp.DataType.build("float") if size <= 32 else exp.DataType.build("double") - ) - - 
return super().datatype_sql(expression) - - def version_sql(self, expression: exp.Version) -> str: - sql = super().version_sql(expression) - return sql.replace("FOR ", "", 1) - - def struct_sql(self, expression: exp.Struct) -> str: - values = [] - - for i, e in enumerate(expression.expressions): - if isinstance(e, exp.PropertyEQ): - self.unsupported("Hive does not support named structs.") - values.append(e.expression) - else: - values.append(e) - - return self.func("STRUCT", *values) - - def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: - return super().columndef_sql( - expression, - sep=( - ": " - if isinstance(expression.parent, exp.DataType) - and expression.parent.is_type("struct") - else sep - ), - ) - - def alterset_sql(self, expression: exp.AlterSet) -> str: - exprs = self.expressions(expression, flat=True) - exprs = f" {exprs}" if exprs else "" - location = self.sql(expression, "location") - location = f" LOCATION {location}" if location else "" - file_format = self.expressions(expression, key="file_format", flat=True, sep=" ") - file_format = f" FILEFORMAT {file_format}" if file_format else "" - serde = self.sql(expression, "serde") - serde = f" SERDE {serde}" if serde else "" - tags = self.expressions(expression, key="tag", flat=True, sep="") - tags = f" TAGS {tags}" if tags else "" - - return f"SET{serde}{exprs}{location}{file_format}{tags}" - - def serdeproperties_sql(self, expression: exp.SerdeProperties) -> str: - prefix = "WITH " if expression.args.get("with") else "" - exprs = self.expressions(expression, flat=True) - - return f"{prefix}SERDEPROPERTIES ({exprs})" - - def exists_sql(self, expression: exp.Exists) -> str: - if expression.expression: - return self.function_fallback_sql(expression) - - return super().exists_sql(expression) - - def timetostr_sql(self, expression: exp.TimeToStr) -> str: - this = expression.this - if isinstance(this, exp.TimeStrToTime): - this = this.this - - return self.func("DATE_FORMAT", this, self.format_time(expression)) diff --git a/altimate_packages/sqlglot/dialects/materialize.py b/altimate_packages/sqlglot/dialects/materialize.py deleted file mode 100644 index 4534c21ba..000000000 --- a/altimate_packages/sqlglot/dialects/materialize.py +++ /dev/null @@ -1,94 +0,0 @@ -from __future__ import annotations - -from sqlglot import exp -from sqlglot.helper import seq_get -from sqlglot.dialects.postgres import Postgres - -from sqlglot.tokens import TokenType -from sqlglot.transforms import ( - remove_unique_constraints, - ctas_with_tmp_tables_to_create_tmp_view, - preprocess, -) -import typing as t - - -class Materialize(Postgres): - class Parser(Postgres.Parser): - NO_PAREN_FUNCTION_PARSERS = { - **Postgres.Parser.NO_PAREN_FUNCTION_PARSERS, - "MAP": lambda self: self._parse_map(), - } - - LAMBDAS = { - **Postgres.Parser.LAMBDAS, - TokenType.FARROW: lambda self, expressions: self.expression( - exp.Kwarg, this=seq_get(expressions, 0), expression=self._parse_assignment() - ), - } - - def _parse_lambda_arg(self) -> t.Optional[exp.Expression]: - return self._parse_field() - - def _parse_map(self) -> exp.ToMap: - if self._match(TokenType.L_PAREN): - to_map = self.expression(exp.ToMap, this=self._parse_select()) - self._match_r_paren() - return to_map - - if not self._match(TokenType.L_BRACKET): - self.raise_error("Expecting [") - - entries = [ - exp.PropertyEQ(this=e.this, expression=e.expression) - for e in self._parse_csv(self._parse_lambda) - ] - - if not self._match(TokenType.R_BRACKET): - self.raise_error("Expecting ]") - - 
return self.expression(exp.ToMap, this=self.expression(exp.Struct, expressions=entries)) - - class Generator(Postgres.Generator): - SUPPORTS_CREATE_TABLE_LIKE = False - - TRANSFORMS = { - **Postgres.Generator.TRANSFORMS, - exp.AutoIncrementColumnConstraint: lambda self, e: "", - exp.Create: preprocess( - [ - remove_unique_constraints, - ctas_with_tmp_tables_to_create_tmp_view, - ] - ), - exp.GeneratedAsIdentityColumnConstraint: lambda self, e: "", - exp.OnConflict: lambda self, e: "", - exp.PrimaryKeyColumnConstraint: lambda self, e: "", - } - TRANSFORMS.pop(exp.ToMap) - - def propertyeq_sql(self, expression: exp.PropertyEQ) -> str: - return self.binary(expression, "=>") - - def datatype_sql(self, expression: exp.DataType) -> str: - if expression.is_type(exp.DataType.Type.LIST): - if expression.expressions: - return f"{self.expressions(expression, flat=True)} LIST" - return "LIST" - - if expression.is_type(exp.DataType.Type.MAP) and len(expression.expressions) == 2: - key, value = expression.expressions - return f"MAP[{self.sql(key)} => {self.sql(value)}]" - - return super().datatype_sql(expression) - - def list_sql(self, expression: exp.List) -> str: - if isinstance(seq_get(expression.expressions, 0), exp.Select): - return self.func("LIST", seq_get(expression.expressions, 0)) - - return f"{self.normalize_func('LIST')}[{self.expressions(expression, flat=True)}]" - - def tomap_sql(self, expression: exp.ToMap) -> str: - if isinstance(expression.this, exp.Select): - return self.func("MAP", expression.this) - return f"{self.normalize_func('MAP')}[{self.expressions(expression.this)}]" diff --git a/altimate_packages/sqlglot/dialects/mysql.py b/altimate_packages/sqlglot/dialects/mysql.py deleted file mode 100644 index a54456fde..000000000 --- a/altimate_packages/sqlglot/dialects/mysql.py +++ /dev/null @@ -1,1324 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - NormalizationStrategy, - arrow_json_extract_sql, - date_add_interval_sql, - datestrtodate_sql, - build_formatted_time, - isnull_to_is_null, - length_or_char_length_sql, - max_or_greatest, - min_or_least, - no_ilike_sql, - no_paren_current_date_sql, - no_pivot_sql, - no_tablesample_sql, - no_trycast_sql, - build_date_delta, - build_date_delta_with_interval, - rename_func, - strposition_sql, - unit_to_var, - trim_sql, - timestrtotime_sql, -) -from sqlglot.generator import unsupported_args -from sqlglot.helper import seq_get -from sqlglot.tokens import TokenType - - -def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[MySQL.Parser], exp.Show]: - def _parse(self: MySQL.Parser) -> exp.Show: - return self._parse_show_mysql(*args, **kwargs) - - return _parse - - -def _date_trunc_sql(self: MySQL.Generator, expression: exp.DateTrunc) -> str: - expr = self.sql(expression, "this") - unit = expression.text("unit").upper() - - if unit == "WEEK": - concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')" - date_format = "%Y %u %w" - elif unit == "MONTH": - concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')" - date_format = "%Y %c %e" - elif unit == "QUARTER": - concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')" - date_format = "%Y %c %e" - elif unit == "YEAR": - concat = f"CONCAT(YEAR({expr}), ' 1 1')" - date_format = "%Y %c %e" - else: - if unit != "DAY": - self.unsupported(f"Unexpected interval unit: {unit}") - return self.func("DATE", expr) - - return self.func("STR_TO_DATE", 
concat, f"'{date_format}'") - - -# All specifiers for time parts (as opposed to date parts) -# https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-format -TIME_SPECIFIERS = {"f", "H", "h", "I", "i", "k", "l", "p", "r", "S", "s", "T"} - - -def _has_time_specifier(date_format: str) -> bool: - i = 0 - length = len(date_format) - - while i < length: - if date_format[i] == "%": - i += 1 - if i < length and date_format[i] in TIME_SPECIFIERS: - return True - i += 1 - return False - - -def _str_to_date(args: t.List) -> exp.StrToDate | exp.StrToTime: - mysql_date_format = seq_get(args, 1) - date_format = MySQL.format_time(mysql_date_format) - this = seq_get(args, 0) - - if mysql_date_format and _has_time_specifier(mysql_date_format.name): - return exp.StrToTime(this=this, format=date_format) - - return exp.StrToDate(this=this, format=date_format) - - -def _str_to_date_sql( - self: MySQL.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate -) -> str: - return self.func("STR_TO_DATE", expression.this, self.format_time(expression)) - - -def _unix_to_time_sql(self: MySQL.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = expression.this - - if scale in (None, exp.UnixToTime.SECONDS): - return self.func("FROM_UNIXTIME", timestamp, self.format_time(expression)) - - return self.func( - "FROM_UNIXTIME", - exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), - self.format_time(expression), - ) - - -def date_add_sql( - kind: str, -) -> t.Callable[[generator.Generator, exp.Expression], str]: - def func(self: generator.Generator, expression: exp.Expression) -> str: - return self.func( - f"DATE_{kind}", - expression.this, - exp.Interval(this=expression.expression, unit=unit_to_var(expression)), - ) - - return func - - -def _ts_or_ds_to_date_sql(self: MySQL.Generator, expression: exp.TsOrDsToDate) -> str: - time_format = expression.args.get("format") - return _str_to_date_sql(self, expression) if time_format else self.func("DATE", expression.this) - - -def _remove_ts_or_ds_to_date( - to_sql: t.Optional[t.Callable[[MySQL.Generator, exp.Expression], str]] = None, - args: t.Tuple[str, ...] = ("this",), -) -> t.Callable[[MySQL.Generator, exp.Func], str]: - def func(self: MySQL.Generator, expression: exp.Func) -> str: - for arg_key in args: - arg = expression.args.get(arg_key) - if isinstance(arg, exp.TsOrDsToDate) and not arg.args.get("format"): - expression.set(arg_key, arg.this) - - return to_sql(self, expression) if to_sql else self.function_fallback_sql(expression) - - return func - - -class MySQL(Dialect): - PROMOTE_TO_INFERRED_DATETIME_TYPE = True - - # https://dev.mysql.com/doc/refman/8.0/en/identifiers.html - IDENTIFIERS_CAN_START_WITH_DIGIT = True - - # We default to treating all identifiers as case-sensitive, since it matches MySQL's - # behavior on Linux systems. For MacOS and Windows systems, one can override this - # setting by specifying `dialect="mysql, normalization_strategy = lowercase"`. 
- # - # See also https://dev.mysql.com/doc/refman/8.2/en/identifier-case-sensitivity.html - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_SENSITIVE - - TIME_FORMAT = "'%Y-%m-%d %T'" - DPIPE_IS_STRING_CONCAT = False - SUPPORTS_USER_DEFINED_TYPES = False - SUPPORTS_SEMI_ANTI_JOIN = False - SAFE_DIVISION = True - - # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions - TIME_MAPPING = { - "%M": "%B", - "%c": "%-m", - "%e": "%-d", - "%h": "%I", - "%i": "%M", - "%s": "%S", - "%u": "%W", - "%k": "%-H", - "%l": "%-I", - "%T": "%H:%M:%S", - "%W": "%A", - } - - class Tokenizer(tokens.Tokenizer): - QUOTES = ["'", '"'] - COMMENTS = ["--", "#", ("/*", "*/")] - IDENTIFIERS = ["`"] - STRING_ESCAPES = ["'", '"', "\\"] - BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")] - HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")] - - NESTED_COMMENTS = False - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "CHARSET": TokenType.CHARACTER_SET, - # The DESCRIBE and EXPLAIN statements are synonyms. - # https://dev.mysql.com/doc/refman/8.4/en/explain.html - "BLOB": TokenType.BLOB, - "EXPLAIN": TokenType.DESCRIBE, - "FORCE": TokenType.FORCE, - "IGNORE": TokenType.IGNORE, - "KEY": TokenType.KEY, - "LOCK TABLES": TokenType.COMMAND, - "LONGBLOB": TokenType.LONGBLOB, - "LONGTEXT": TokenType.LONGTEXT, - "MEDIUMBLOB": TokenType.MEDIUMBLOB, - "TINYBLOB": TokenType.TINYBLOB, - "TINYTEXT": TokenType.TINYTEXT, - "MEDIUMTEXT": TokenType.MEDIUMTEXT, - "MEDIUMINT": TokenType.MEDIUMINT, - "MEMBER OF": TokenType.MEMBER_OF, - "SEPARATOR": TokenType.SEPARATOR, - "SERIAL": TokenType.SERIAL, - "START": TokenType.BEGIN, - "SIGNED": TokenType.BIGINT, - "SIGNED INTEGER": TokenType.BIGINT, - "TIMESTAMP": TokenType.TIMESTAMPTZ, - "UNLOCK TABLES": TokenType.COMMAND, - "UNSIGNED": TokenType.UBIGINT, - "UNSIGNED INTEGER": TokenType.UBIGINT, - "YEAR": TokenType.YEAR, - "_ARMSCII8": TokenType.INTRODUCER, - "_ASCII": TokenType.INTRODUCER, - "_BIG5": TokenType.INTRODUCER, - "_BINARY": TokenType.INTRODUCER, - "_CP1250": TokenType.INTRODUCER, - "_CP1251": TokenType.INTRODUCER, - "_CP1256": TokenType.INTRODUCER, - "_CP1257": TokenType.INTRODUCER, - "_CP850": TokenType.INTRODUCER, - "_CP852": TokenType.INTRODUCER, - "_CP866": TokenType.INTRODUCER, - "_CP932": TokenType.INTRODUCER, - "_DEC8": TokenType.INTRODUCER, - "_EUCJPMS": TokenType.INTRODUCER, - "_EUCKR": TokenType.INTRODUCER, - "_GB18030": TokenType.INTRODUCER, - "_GB2312": TokenType.INTRODUCER, - "_GBK": TokenType.INTRODUCER, - "_GEOSTD8": TokenType.INTRODUCER, - "_GREEK": TokenType.INTRODUCER, - "_HEBREW": TokenType.INTRODUCER, - "_HP8": TokenType.INTRODUCER, - "_KEYBCS2": TokenType.INTRODUCER, - "_KOI8R": TokenType.INTRODUCER, - "_KOI8U": TokenType.INTRODUCER, - "_LATIN1": TokenType.INTRODUCER, - "_LATIN2": TokenType.INTRODUCER, - "_LATIN5": TokenType.INTRODUCER, - "_LATIN7": TokenType.INTRODUCER, - "_MACCE": TokenType.INTRODUCER, - "_MACROMAN": TokenType.INTRODUCER, - "_SJIS": TokenType.INTRODUCER, - "_SWE7": TokenType.INTRODUCER, - "_TIS620": TokenType.INTRODUCER, - "_UCS2": TokenType.INTRODUCER, - "_UJIS": TokenType.INTRODUCER, - # https://dev.mysql.com/doc/refman/8.0/en/string-literals.html - "_UTF8": TokenType.INTRODUCER, - "_UTF16": TokenType.INTRODUCER, - "_UTF16LE": TokenType.INTRODUCER, - "_UTF32": TokenType.INTRODUCER, - "_UTF8MB3": TokenType.INTRODUCER, - "_UTF8MB4": TokenType.INTRODUCER, - "@@": TokenType.SESSION_PARAMETER, - } - - COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.REPLACE} - {TokenType.SHOW} - - class Parser(parser.Parser): - 
FUNC_TOKENS = { - *parser.Parser.FUNC_TOKENS, - TokenType.DATABASE, - TokenType.SCHEMA, - TokenType.VALUES, - } - - CONJUNCTION = { - **parser.Parser.CONJUNCTION, - TokenType.DAMP: exp.And, - TokenType.XOR: exp.Xor, - } - - DISJUNCTION = { - **parser.Parser.DISJUNCTION, - TokenType.DPIPE: exp.Or, - } - - TABLE_ALIAS_TOKENS = ( - parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS - ) - - RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, - TokenType.MEMBER_OF: lambda self, this: self.expression( - exp.JSONArrayContains, - this=this, - expression=self._parse_wrapped(self._parse_expression), - ), - } - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "CONVERT_TZ": lambda args: exp.ConvertTimezone( - source_tz=seq_get(args, 1), target_tz=seq_get(args, 2), timestamp=seq_get(args, 0) - ), - "CURDATE": exp.CurrentDate.from_arg_list, - "DATE": lambda args: exp.TsOrDsToDate(this=seq_get(args, 0)), - "DATE_ADD": build_date_delta_with_interval(exp.DateAdd), - "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "mysql"), - "DATE_SUB": build_date_delta_with_interval(exp.DateSub), - "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "FORMAT": exp.NumberToStr.from_arg_list, - "FROM_UNIXTIME": build_formatted_time(exp.UnixToTime, "mysql"), - "ISNULL": isnull_to_is_null, - "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True), - "MAKETIME": exp.TimeFromParts.from_arg_list, - "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "MONTHNAME": lambda args: exp.TimeToStr( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - format=exp.Literal.string("%B"), - ), - "SCHEMA": exp.CurrentSchema.from_arg_list, - "DATABASE": exp.CurrentSchema.from_arg_list, - "STR_TO_DATE": _str_to_date, - "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff), - "TO_DAYS": lambda args: exp.paren( - exp.DateDiff( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - expression=exp.TsOrDsToDate(this=exp.Literal.string("0000-01-01")), - unit=exp.var("DAY"), - ) - + 1 - ), - "WEEK": lambda args: exp.Week( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), mode=seq_get(args, 1) - ), - "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - } - - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, - "CHAR": lambda self: self.expression( - exp.Chr, - expressions=self._parse_csv(self._parse_assignment), - charset=self._match(TokenType.USING) and self._parse_var(), - ), - "GROUP_CONCAT": lambda self: self._parse_group_concat(), - # https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_values - "VALUES": lambda self: self.expression( - exp.Anonymous, this="VALUES", expressions=[self._parse_id_var()] - ), - "JSON_VALUE": lambda self: self._parse_json_value(), - } - - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, - TokenType.SHOW: lambda self: self._parse_show(), - } - - SHOW_PARSERS = { - "BINARY LOGS": _show_parser("BINARY LOGS"), - "MASTER LOGS": _show_parser("BINARY LOGS"), - "BINLOG EVENTS": _show_parser("BINLOG EVENTS"), - "CHARACTER SET": _show_parser("CHARACTER SET"), - "CHARSET": _show_parser("CHARACTER SET"), - "COLLATION": 
_show_parser("COLLATION"), - "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True), - "COLUMNS": _show_parser("COLUMNS", target="FROM"), - "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True), - "CREATE EVENT": _show_parser("CREATE EVENT", target=True), - "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True), - "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True), - "CREATE TABLE": _show_parser("CREATE TABLE", target=True), - "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True), - "CREATE VIEW": _show_parser("CREATE VIEW", target=True), - "DATABASES": _show_parser("DATABASES"), - "SCHEMAS": _show_parser("DATABASES"), - "ENGINE": _show_parser("ENGINE", target=True), - "STORAGE ENGINES": _show_parser("ENGINES"), - "ENGINES": _show_parser("ENGINES"), - "ERRORS": _show_parser("ERRORS"), - "EVENTS": _show_parser("EVENTS"), - "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True), - "FUNCTION STATUS": _show_parser("FUNCTION STATUS"), - "GRANTS": _show_parser("GRANTS", target="FOR"), - "INDEX": _show_parser("INDEX", target="FROM"), - "MASTER STATUS": _show_parser("MASTER STATUS"), - "OPEN TABLES": _show_parser("OPEN TABLES"), - "PLUGINS": _show_parser("PLUGINS"), - "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True), - "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"), - "PRIVILEGES": _show_parser("PRIVILEGES"), - "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True), - "PROCESSLIST": _show_parser("PROCESSLIST"), - "PROFILE": _show_parser("PROFILE"), - "PROFILES": _show_parser("PROFILES"), - "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"), - "REPLICAS": _show_parser("REPLICAS"), - "SLAVE HOSTS": _show_parser("REPLICAS"), - "REPLICA STATUS": _show_parser("REPLICA STATUS"), - "SLAVE STATUS": _show_parser("REPLICA STATUS"), - "GLOBAL STATUS": _show_parser("STATUS", global_=True), - "SESSION STATUS": _show_parser("STATUS"), - "STATUS": _show_parser("STATUS"), - "TABLE STATUS": _show_parser("TABLE STATUS"), - "FULL TABLES": _show_parser("TABLES", full=True), - "TABLES": _show_parser("TABLES"), - "TRIGGERS": _show_parser("TRIGGERS"), - "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True), - "SESSION VARIABLES": _show_parser("VARIABLES"), - "VARIABLES": _show_parser("VARIABLES"), - "WARNINGS": _show_parser("WARNINGS"), - } - - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, - "LOCK": lambda self: self._parse_property_assignment(exp.LockProperty), - } - - SET_PARSERS = { - **parser.Parser.SET_PARSERS, - "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"), - "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"), - "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"), - "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"), - "NAMES": lambda self: self._parse_set_item_names(), - } - - CONSTRAINT_PARSERS = { - **parser.Parser.CONSTRAINT_PARSERS, - "FULLTEXT": lambda self: self._parse_index_constraint(kind="FULLTEXT"), - "INDEX": lambda self: self._parse_index_constraint(), - "KEY": lambda self: self._parse_index_constraint(), - "SPATIAL": lambda self: self._parse_index_constraint(kind="SPATIAL"), - } - - ALTER_PARSERS = { - **parser.Parser.ALTER_PARSERS, - "MODIFY": lambda self: self._parse_alter_table_alter(), - } - - ALTER_ALTER_PARSERS = { - **parser.Parser.ALTER_ALTER_PARSERS, - "INDEX": lambda self: self._parse_alter_table_alter_index(), - } - - SCHEMA_UNNAMED_CONSTRAINTS = { - *parser.Parser.SCHEMA_UNNAMED_CONSTRAINTS, - 
"FULLTEXT", - "INDEX", - "KEY", - "SPATIAL", - } - - PROFILE_TYPES: parser.OPTIONS_TYPE = { - **dict.fromkeys(("ALL", "CPU", "IPC", "MEMORY", "SOURCE", "SWAPS"), tuple()), - "BLOCK": ("IO",), - "CONTEXT": ("SWITCHES",), - "PAGE": ("FAULTS",), - } - - TYPE_TOKENS = { - *parser.Parser.TYPE_TOKENS, - TokenType.SET, - } - - ENUM_TYPE_TOKENS = { - *parser.Parser.ENUM_TYPE_TOKENS, - TokenType.SET, - } - - # SELECT [ ALL | DISTINCT | DISTINCTROW ] [ ] - OPERATION_MODIFIERS = { - "HIGH_PRIORITY", - "STRAIGHT_JOIN", - "SQL_SMALL_RESULT", - "SQL_BIG_RESULT", - "SQL_BUFFER_RESULT", - "SQL_NO_CACHE", - "SQL_CALC_FOUND_ROWS", - } - - LOG_DEFAULTS_TO_LN = True - STRING_ALIASES = True - VALUES_FOLLOWED_BY_PAREN = False - SUPPORTS_PARTITION_SELECTION = True - - def _parse_generated_as_identity( - self, - ) -> ( - exp.GeneratedAsIdentityColumnConstraint - | exp.ComputedColumnConstraint - | exp.GeneratedAsRowColumnConstraint - ): - this = super()._parse_generated_as_identity() - - if self._match_texts(("STORED", "VIRTUAL")): - persisted = self._prev.text.upper() == "STORED" - - if isinstance(this, exp.ComputedColumnConstraint): - this.set("persisted", persisted) - elif isinstance(this, exp.GeneratedAsIdentityColumnConstraint): - this = self.expression( - exp.ComputedColumnConstraint, this=this.expression, persisted=persisted - ) - - return this - - def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: - this = self._parse_id_var() - if not self._match(TokenType.L_PAREN): - return this - - expression = self._parse_number() - self._match_r_paren() - return self.expression(exp.ColumnPrefix, this=this, expression=expression) - - def _parse_index_constraint( - self, kind: t.Optional[str] = None - ) -> exp.IndexColumnConstraint: - if kind: - self._match_texts(("INDEX", "KEY")) - - this = self._parse_id_var(any_token=False) - index_type = self._match(TokenType.USING) and self._advance_any() and self._prev.text - expressions = self._parse_wrapped_csv(self._parse_ordered) - - options = [] - while True: - if self._match_text_seq("KEY_BLOCK_SIZE"): - self._match(TokenType.EQ) - opt = exp.IndexConstraintOption(key_block_size=self._parse_number()) - elif self._match(TokenType.USING): - opt = exp.IndexConstraintOption(using=self._advance_any() and self._prev.text) - elif self._match_text_seq("WITH", "PARSER"): - opt = exp.IndexConstraintOption(parser=self._parse_var(any_token=True)) - elif self._match(TokenType.COMMENT): - opt = exp.IndexConstraintOption(comment=self._parse_string()) - elif self._match_text_seq("VISIBLE"): - opt = exp.IndexConstraintOption(visible=True) - elif self._match_text_seq("INVISIBLE"): - opt = exp.IndexConstraintOption(visible=False) - elif self._match_text_seq("ENGINE_ATTRIBUTE"): - self._match(TokenType.EQ) - opt = exp.IndexConstraintOption(engine_attr=self._parse_string()) - elif self._match_text_seq("SECONDARY_ENGINE_ATTRIBUTE"): - self._match(TokenType.EQ) - opt = exp.IndexConstraintOption(secondary_engine_attr=self._parse_string()) - else: - opt = None - - if not opt: - break - - options.append(opt) - - return self.expression( - exp.IndexColumnConstraint, - this=this, - expressions=expressions, - kind=kind, - index_type=index_type, - options=options, - ) - - def _parse_show_mysql( - self, - this: str, - target: bool | str = False, - full: t.Optional[bool] = None, - global_: t.Optional[bool] = None, - ) -> exp.Show: - if target: - if isinstance(target, str): - self._match_text_seq(target) - target_id = self._parse_id_var() - else: - target_id = None - - log = self._parse_string() 
if self._match_text_seq("IN") else None - - if this in ("BINLOG EVENTS", "RELAYLOG EVENTS"): - position = self._parse_number() if self._match_text_seq("FROM") else None - db = None - else: - position = None - db = None - - if self._match(TokenType.FROM): - db = self._parse_id_var() - elif self._match(TokenType.DOT): - db = target_id - target_id = self._parse_id_var() - - channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None - - like = self._parse_string() if self._match_text_seq("LIKE") else None - where = self._parse_where() - - if this == "PROFILE": - types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES)) - query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None - offset = self._parse_number() if self._match_text_seq("OFFSET") else None - limit = self._parse_number() if self._match_text_seq("LIMIT") else None - else: - types, query = None, None - offset, limit = self._parse_oldstyle_limit() - - mutex = True if self._match_text_seq("MUTEX") else None - mutex = False if self._match_text_seq("STATUS") else mutex - - return self.expression( - exp.Show, - this=this, - target=target_id, - full=full, - log=log, - position=position, - db=db, - channel=channel, - like=like, - where=where, - types=types, - query=query, - offset=offset, - limit=limit, - mutex=mutex, - **{"global": global_}, # type: ignore - ) - - def _parse_oldstyle_limit( - self, - ) -> t.Tuple[t.Optional[exp.Expression], t.Optional[exp.Expression]]: - limit = None - offset = None - if self._match_text_seq("LIMIT"): - parts = self._parse_csv(self._parse_number) - if len(parts) == 1: - limit = parts[0] - elif len(parts) == 2: - limit = parts[1] - offset = parts[0] - - return offset, limit - - def _parse_set_item_charset(self, kind: str) -> exp.Expression: - this = self._parse_string() or self._parse_unquoted_field() - return self.expression(exp.SetItem, this=this, kind=kind) - - def _parse_set_item_names(self) -> exp.Expression: - charset = self._parse_string() or self._parse_unquoted_field() - if self._match_text_seq("COLLATE"): - collate = self._parse_string() or self._parse_unquoted_field() - else: - collate = None - - return self.expression(exp.SetItem, this=charset, collate=collate, kind="NAMES") - - def _parse_type( - self, parse_interval: bool = True, fallback_to_identifier: bool = False - ) -> t.Optional[exp.Expression]: - # mysql binary is special and can work anywhere, even in order by operations - # it operates like a no paren func - if self._match(TokenType.BINARY, advance=False): - data_type = self._parse_types(check_func=True, allow_identifiers=False) - - if isinstance(data_type, exp.DataType): - return self.expression(exp.Cast, this=self._parse_column(), to=data_type) - - return super()._parse_type( - parse_interval=parse_interval, fallback_to_identifier=fallback_to_identifier - ) - - def _parse_group_concat(self) -> t.Optional[exp.Expression]: - def concat_exprs( - node: t.Optional[exp.Expression], exprs: t.List[exp.Expression] - ) -> exp.Expression: - if isinstance(node, exp.Distinct) and len(node.expressions) > 1: - concat_exprs = [ - self.expression(exp.Concat, expressions=node.expressions, safe=True) - ] - node.set("expressions", concat_exprs) - return node - if len(exprs) == 1: - return exprs[0] - return self.expression(exp.Concat, expressions=args, safe=True) - - args = self._parse_csv(self._parse_lambda) - - if args: - order = args[-1] if isinstance(args[-1], exp.Order) else None - - if order: - # Order By is the last (or only) 
expression in the list and has consumed the 'expr' before it, - # remove 'expr' from exp.Order and add it back to args - args[-1] = order.this - order.set("this", concat_exprs(order.this, args)) - - this = order or concat_exprs(args[0], args) - else: - this = None - - separator = self._parse_field() if self._match(TokenType.SEPARATOR) else None - - return self.expression(exp.GroupConcat, this=this, separator=separator) - - def _parse_json_value(self) -> exp.JSONValue: - this = self._parse_bitwise() - self._match(TokenType.COMMA) - path = self._parse_bitwise() - - returning = self._match(TokenType.RETURNING) and self._parse_type() - - return self.expression( - exp.JSONValue, - this=this, - path=self.dialect.to_json_path(path), - returning=returning, - on_condition=self._parse_on_condition(), - ) - - def _parse_alter_table_alter_index(self) -> exp.AlterIndex: - index = self._parse_field(any_token=True) - - if self._match_text_seq("VISIBLE"): - visible = True - elif self._match_text_seq("INVISIBLE"): - visible = False - else: - visible = None - - return self.expression(exp.AlterIndex, this=index, visible=visible) - - class Generator(generator.Generator): - INTERVAL_ALLOWS_PLURAL_FORM = False - LOCKING_READS_SUPPORTED = True - NULL_ORDERING_SUPPORTED = None - JOIN_HINTS = False - TABLE_HINTS = True - DUPLICATE_KEY_UPDATE_WITH_SET = False - QUERY_HINT_SEP = " " - VALUES_AS_TABLE = False - NVL2_SUPPORTED = False - LAST_DAY_SUPPORTS_DATE_PART = False - JSON_TYPE_REQUIRED_FOR_EXTRACTION = True - JSON_PATH_BRACKETED_KEY_SUPPORTED = False - JSON_KEY_VALUE_PAIR_SEP = "," - SUPPORTS_TO_NUMBER = False - PARSE_JSON_NAME: t.Optional[str] = None - PAD_FILL_PATTERN_IS_REQUIRED = True - WRAP_DERIVED_VALUES = False - VARCHAR_REQUIRES_SIZE = True - SUPPORTS_MEDIAN = False - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.ArrayAgg: rename_func("GROUP_CONCAT"), - exp.CurrentDate: no_paren_current_date_sql, - exp.DateDiff: _remove_ts_or_ds_to_date( - lambda self, e: self.func("DATEDIFF", e.this, e.expression), ("this", "expression") - ), - exp.DateAdd: _remove_ts_or_ds_to_date(date_add_sql("ADD")), - exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _remove_ts_or_ds_to_date(date_add_sql("SUB")), - exp.DateTrunc: _date_trunc_sql, - exp.Day: _remove_ts_or_ds_to_date(), - exp.DayOfMonth: _remove_ts_or_ds_to_date(rename_func("DAYOFMONTH")), - exp.DayOfWeek: _remove_ts_or_ds_to_date(rename_func("DAYOFWEEK")), - exp.DayOfYear: _remove_ts_or_ds_to_date(rename_func("DAYOFYEAR")), - exp.GroupConcat: lambda self, - e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""", - exp.ILike: no_ilike_sql, - exp.JSONExtractScalar: arrow_json_extract_sql, - exp.Length: length_or_char_length_sql, - exp.LogicalOr: rename_func("MAX"), - exp.LogicalAnd: rename_func("MIN"), - exp.Max: max_or_greatest, - exp.Min: min_or_least, - exp.Month: _remove_ts_or_ds_to_date(), - exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"), - exp.NullSafeNEQ: lambda self, e: f"NOT {self.binary(e, '<=>')}", - exp.NumberToStr: rename_func("FORMAT"), - exp.Pivot: no_pivot_sql, - exp.Select: transforms.preprocess( - [ - transforms.eliminate_distinct_on, - transforms.eliminate_semi_and_anti_joins, - transforms.eliminate_qualify, - transforms.eliminate_full_outer_join, - transforms.unnest_generate_date_array_using_recursive_cte, - ] - ), - exp.StrPosition: lambda self, e: strposition_sql( - self, e, func_name="LOCATE", supports_position=True - ), - exp.StrToDate: _str_to_date_sql, - exp.StrToTime: 
_str_to_date_sql, - exp.Stuff: rename_func("INSERT"), - exp.TableSample: no_tablesample_sql, - exp.TimeFromParts: rename_func("MAKETIME"), - exp.TimestampAdd: date_add_interval_sql("DATE", "ADD"), - exp.TimestampDiff: lambda self, e: self.func( - "TIMESTAMPDIFF", unit_to_var(e), e.expression, e.this - ), - exp.TimestampSub: date_add_interval_sql("DATE", "SUB"), - exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), - exp.TimeStrToTime: lambda self, e: timestrtotime_sql( - self, - e, - include_precision=not e.args.get("zone"), - ), - exp.TimeToStr: _remove_ts_or_ds_to_date( - lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)) - ), - exp.Trim: trim_sql, - exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: date_add_sql("ADD"), - exp.TsOrDsDiff: lambda self, e: self.func("DATEDIFF", e.this, e.expression), - exp.TsOrDsToDate: _ts_or_ds_to_date_sql, - exp.Unicode: lambda self, e: f"ORD(CONVERT({self.sql(e.this)} USING utf32))", - exp.UnixToTime: _unix_to_time_sql, - exp.Week: _remove_ts_or_ds_to_date(), - exp.WeekOfYear: _remove_ts_or_ds_to_date(rename_func("WEEKOFYEAR")), - exp.Year: _remove_ts_or_ds_to_date(), - } - - UNSIGNED_TYPE_MAPPING = { - exp.DataType.Type.UBIGINT: "BIGINT", - exp.DataType.Type.UINT: "INT", - exp.DataType.Type.UMEDIUMINT: "MEDIUMINT", - exp.DataType.Type.USMALLINT: "SMALLINT", - exp.DataType.Type.UTINYINT: "TINYINT", - exp.DataType.Type.UDECIMAL: "DECIMAL", - exp.DataType.Type.UDOUBLE: "DOUBLE", - } - - TIMESTAMP_TYPE_MAPPING = { - exp.DataType.Type.DATETIME2: "DATETIME", - exp.DataType.Type.SMALLDATETIME: "DATETIME", - exp.DataType.Type.TIMESTAMP: "DATETIME", - exp.DataType.Type.TIMESTAMPNTZ: "DATETIME", - exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP", - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - **UNSIGNED_TYPE_MAPPING, - **TIMESTAMP_TYPE_MAPPING, - } - - TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT) - TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT) - TYPE_MAPPING.pop(exp.DataType.Type.TINYTEXT) - TYPE_MAPPING.pop(exp.DataType.Type.BLOB) - TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB) - TYPE_MAPPING.pop(exp.DataType.Type.LONGBLOB) - TYPE_MAPPING.pop(exp.DataType.Type.TINYBLOB) - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - LIMIT_FETCH = "LIMIT" - - LIMIT_ONLY_LITERALS = True - - CHAR_CAST_MAPPING = dict.fromkeys( - ( - exp.DataType.Type.LONGTEXT, - exp.DataType.Type.LONGBLOB, - exp.DataType.Type.MEDIUMBLOB, - exp.DataType.Type.MEDIUMTEXT, - exp.DataType.Type.TEXT, - exp.DataType.Type.TINYBLOB, - exp.DataType.Type.TINYTEXT, - exp.DataType.Type.VARCHAR, - ), - "CHAR", - ) - SIGNED_CAST_MAPPING = dict.fromkeys( - ( - exp.DataType.Type.BIGINT, - exp.DataType.Type.BOOLEAN, - exp.DataType.Type.INT, - exp.DataType.Type.SMALLINT, - exp.DataType.Type.TINYINT, - exp.DataType.Type.MEDIUMINT, - ), - "SIGNED", - ) - - # MySQL doesn't support many datatypes in cast. 
- # https://dev.mysql.com/doc/refman/8.0/en/cast-functions.html#function_cast - CAST_MAPPING = { - **CHAR_CAST_MAPPING, - **SIGNED_CAST_MAPPING, - exp.DataType.Type.UBIGINT: "UNSIGNED", - } - - TIMESTAMP_FUNC_TYPES = { - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, - } - - # https://dev.mysql.com/doc/refman/8.0/en/keywords.html - RESERVED_KEYWORDS = { - "accessible", - "add", - "all", - "alter", - "analyze", - "and", - "as", - "asc", - "asensitive", - "before", - "between", - "bigint", - "binary", - "blob", - "both", - "by", - "call", - "cascade", - "case", - "change", - "char", - "character", - "check", - "collate", - "column", - "condition", - "constraint", - "continue", - "convert", - "create", - "cross", - "cube", - "cume_dist", - "current_date", - "current_time", - "current_timestamp", - "current_user", - "cursor", - "database", - "databases", - "day_hour", - "day_microsecond", - "day_minute", - "day_second", - "dec", - "decimal", - "declare", - "default", - "delayed", - "delete", - "dense_rank", - "desc", - "describe", - "deterministic", - "distinct", - "distinctrow", - "div", - "double", - "drop", - "dual", - "each", - "else", - "elseif", - "empty", - "enclosed", - "escaped", - "except", - "exists", - "exit", - "explain", - "false", - "fetch", - "first_value", - "float", - "float4", - "float8", - "for", - "force", - "foreign", - "from", - "fulltext", - "function", - "generated", - "get", - "grant", - "group", - "grouping", - "groups", - "having", - "high_priority", - "hour_microsecond", - "hour_minute", - "hour_second", - "if", - "ignore", - "in", - "index", - "infile", - "inner", - "inout", - "insensitive", - "insert", - "int", - "int1", - "int2", - "int3", - "int4", - "int8", - "integer", - "intersect", - "interval", - "into", - "io_after_gtids", - "io_before_gtids", - "is", - "iterate", - "join", - "json_table", - "key", - "keys", - "kill", - "lag", - "last_value", - "lateral", - "lead", - "leading", - "leave", - "left", - "like", - "limit", - "linear", - "lines", - "load", - "localtime", - "localtimestamp", - "lock", - "long", - "longblob", - "longtext", - "loop", - "low_priority", - "master_bind", - "master_ssl_verify_server_cert", - "match", - "maxvalue", - "mediumblob", - "mediumint", - "mediumtext", - "middleint", - "minute_microsecond", - "minute_second", - "mod", - "modifies", - "natural", - "not", - "no_write_to_binlog", - "nth_value", - "ntile", - "null", - "numeric", - "of", - "on", - "optimize", - "optimizer_costs", - "option", - "optionally", - "or", - "order", - "out", - "outer", - "outfile", - "over", - "partition", - "percent_rank", - "precision", - "primary", - "procedure", - "purge", - "range", - "rank", - "read", - "reads", - "read_write", - "real", - "recursive", - "references", - "regexp", - "release", - "rename", - "repeat", - "replace", - "require", - "resignal", - "restrict", - "return", - "revoke", - "right", - "rlike", - "row", - "rows", - "row_number", - "schema", - "schemas", - "second_microsecond", - "select", - "sensitive", - "separator", - "set", - "show", - "signal", - "smallint", - "spatial", - "specific", - "sql", - "sqlexception", - "sqlstate", - "sqlwarning", - "sql_big_result", - "sql_calc_found_rows", - "sql_small_result", - "ssl", - "starting", - "stored", - "straight_join", - "system", - "table", - "terminated", - "then", - "tinyblob", - "tinyint", - "tinytext", - "to", - "trailing", - "trigger", - "true", - "undo", - "union", - "unique", - "unlock", - "unsigned", - "update", - "usage", - "use", - "using", - "utc_date", - 
"utc_time", - "utc_timestamp", - "values", - "varbinary", - "varchar", - "varcharacter", - "varying", - "virtual", - "when", - "where", - "while", - "window", - "with", - "write", - "xor", - "year_month", - "zerofill", - } - - def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: - persisted = "STORED" if expression.args.get("persisted") else "VIRTUAL" - return f"GENERATED ALWAYS AS ({self.sql(expression.this.unnest())}) {persisted}" - - def array_sql(self, expression: exp.Array) -> str: - self.unsupported("Arrays are not supported by MySQL") - return self.function_fallback_sql(expression) - - def arraycontainsall_sql(self, expression: exp.ArrayContainsAll) -> str: - self.unsupported("Array operations are not supported by MySQL") - return self.function_fallback_sql(expression) - - def dpipe_sql(self, expression: exp.DPipe) -> str: - return self.func("CONCAT", *expression.flatten()) - - def extract_sql(self, expression: exp.Extract) -> str: - unit = expression.name - if unit and unit.lower() == "epoch": - return self.func("UNIX_TIMESTAMP", expression.expression) - - return super().extract_sql(expression) - - def datatype_sql(self, expression: exp.DataType) -> str: - if ( - self.VARCHAR_REQUIRES_SIZE - and expression.is_type(exp.DataType.Type.VARCHAR) - and not expression.expressions - ): - # `VARCHAR` must always have a size - if it doesn't, we always generate `TEXT` - return "TEXT" - - # https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html - result = super().datatype_sql(expression) - if expression.this in self.UNSIGNED_TYPE_MAPPING: - result = f"{result} UNSIGNED" - - return result - - def jsonarraycontains_sql(self, expression: exp.JSONArrayContains) -> str: - return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})" - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - if expression.to.this in self.TIMESTAMP_FUNC_TYPES: - return self.func("TIMESTAMP", expression.this) - - to = self.CAST_MAPPING.get(expression.to.this) - - if to: - expression.to.set("this", to) - return super().cast_sql(expression) - - def show_sql(self, expression: exp.Show) -> str: - this = f" {expression.name}" - full = " FULL" if expression.args.get("full") else "" - global_ = " GLOBAL" if expression.args.get("global") else "" - - target = self.sql(expression, "target") - target = f" {target}" if target else "" - if expression.name in ("COLUMNS", "INDEX"): - target = f" FROM{target}" - elif expression.name == "GRANTS": - target = f" FOR{target}" - - db = self._prefixed_sql("FROM", expression, "db") - - like = self._prefixed_sql("LIKE", expression, "like") - where = self.sql(expression, "where") - - types = self.expressions(expression, key="types") - types = f" {types}" if types else types - query = self._prefixed_sql("FOR QUERY", expression, "query") - - if expression.name == "PROFILE": - offset = self._prefixed_sql("OFFSET", expression, "offset") - limit = self._prefixed_sql("LIMIT", expression, "limit") - else: - offset = "" - limit = self._oldstyle_limit_sql(expression) - - log = self._prefixed_sql("IN", expression, "log") - position = self._prefixed_sql("FROM", expression, "position") - - channel = self._prefixed_sql("FOR CHANNEL", expression, "channel") - - if expression.name == "ENGINE": - mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS" - else: - mutex_or_status = "" - - return 
f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}" - - def altercolumn_sql(self, expression: exp.AlterColumn) -> str: - dtype = self.sql(expression, "dtype") - if not dtype: - return super().altercolumn_sql(expression) - - this = self.sql(expression, "this") - return f"MODIFY COLUMN {this} {dtype}" - - def _prefixed_sql(self, prefix: str, expression: exp.Expression, arg: str) -> str: - sql = self.sql(expression, arg) - return f" {prefix} {sql}" if sql else "" - - def _oldstyle_limit_sql(self, expression: exp.Show) -> str: - limit = self.sql(expression, "limit") - offset = self.sql(expression, "offset") - if limit: - limit_offset = f"{offset}, {limit}" if offset else limit - return f" LIMIT {limit_offset}" - return "" - - def chr_sql(self, expression: exp.Chr) -> str: - this = self.expressions(sqls=[expression.this] + expression.expressions) - charset = expression.args.get("charset") - using = f" USING {self.sql(charset)}" if charset else "" - return f"CHAR({this}{using})" - - def timestamptrunc_sql(self, expression: exp.TimestampTrunc) -> str: - unit = expression.args.get("unit") - - # Pick an old-enough date to avoid negative timestamp diffs - start_ts = "'0000-01-01 00:00:00'" - - # Source: https://stackoverflow.com/a/32955740 - timestamp_diff = build_date_delta(exp.TimestampDiff)([unit, start_ts, expression.this]) - interval = exp.Interval(this=timestamp_diff, unit=unit) - dateadd = build_date_delta_with_interval(exp.DateAdd)([start_ts, interval]) - - return self.sql(dateadd) - - def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str: - from_tz = expression.args.get("source_tz") - to_tz = expression.args.get("target_tz") - dt = expression.args.get("timestamp") - - return self.func("CONVERT_TZ", dt, from_tz, to_tz) - - def attimezone_sql(self, expression: exp.AtTimeZone) -> str: - self.unsupported("AT TIME ZONE is not supported by MySQL") - return self.sql(expression.this) - - def isascii_sql(self, expression: exp.IsAscii) -> str: - return f"REGEXP_LIKE({self.sql(expression.this)}, '^[[:ascii:]]*$')" - - @unsupported_args("this") - def currentschema_sql(self, expression: exp.CurrentSchema) -> str: - return self.func("SCHEMA") diff --git a/altimate_packages/sqlglot/dialects/oracle.py b/altimate_packages/sqlglot/dialects/oracle.py deleted file mode 100644 index ec2f77f67..000000000 --- a/altimate_packages/sqlglot/dialects/oracle.py +++ /dev/null @@ -1,378 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - NormalizationStrategy, - build_timetostr_or_tochar, - build_formatted_time, - no_ilike_sql, - rename_func, - strposition_sql, - to_number_with_nls_param, - trim_sql, -) -from sqlglot.helper import seq_get -from sqlglot.parser import OPTIONS_TYPE, build_coalesce -from sqlglot.tokens import TokenType - -if t.TYPE_CHECKING: - from sqlglot._typing import E - - -def _trim_sql(self: Oracle.Generator, expression: exp.Trim) -> str: - position = expression.args.get("position") - - if position and position.upper() in ("LEADING", "TRAILING"): - return self.trim_sql(expression) - - return trim_sql(self, expression) - - -def _build_to_timestamp(args: t.List) -> exp.StrToTime | exp.Anonymous: - if len(args) == 1: - return exp.Anonymous(this="TO_TIMESTAMP", expressions=args) - - return build_formatted_time(exp.StrToTime, "oracle")(args) - - -class Oracle(Dialect): - 
ALIAS_POST_TABLESAMPLE = True - LOCKING_READS_SUPPORTED = True - TABLESAMPLE_SIZE_IS_PERCENT = True - NULL_ORDERING = "nulls_are_large" - ON_CONDITION_EMPTY_BEFORE_ERROR = False - ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False - - # See section 8: https://docs.oracle.com/cd/A97630_01/server.920/a96540/sql_elements9a.htm - NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE - - # https://docs.oracle.com/database/121/SQLRF/sql_elements004.htm#SQLRF00212 - # https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes - TIME_MAPPING = { - "AM": "%p", # Meridian indicator with or without periods - "A.M.": "%p", # Meridian indicator with or without periods - "PM": "%p", # Meridian indicator with or without periods - "P.M.": "%p", # Meridian indicator with or without periods - "D": "%u", # Day of week (1-7) - "DAY": "%A", # name of day - "DD": "%d", # day of month (1-31) - "DDD": "%j", # day of year (1-366) - "DY": "%a", # abbreviated name of day - "HH": "%I", # Hour of day (1-12) - "HH12": "%I", # alias for HH - "HH24": "%H", # Hour of day (0-23) - "IW": "%V", # Calendar week of year (1-52 or 1-53), as defined by the ISO 8601 standard - "MI": "%M", # Minute (0-59) - "MM": "%m", # Month (01-12; January = 01) - "MON": "%b", # Abbreviated name of month - "MONTH": "%B", # Name of month - "SS": "%S", # Second (0-59) - "WW": "%W", # Week of year (1-53) - "YY": "%y", # 15 - "YYYY": "%Y", # 2015 - "FF6": "%f", # only 6 digits are supported in python formats - } - - class Tokenizer(tokens.Tokenizer): - VAR_SINGLE_TOKENS = {"@", "$", "#"} - - UNICODE_STRINGS = [ - (prefix + q, q) - for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES) - for prefix in ("U", "u") - ] - - NESTED_COMMENTS = False - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "(+)": TokenType.JOIN_MARKER, - "BINARY_DOUBLE": TokenType.DOUBLE, - "BINARY_FLOAT": TokenType.FLOAT, - "BULK COLLECT INTO": TokenType.BULK_COLLECT_INTO, - "COLUMNS": TokenType.COLUMN, - "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, - "MINUS": TokenType.EXCEPT, - "NVARCHAR2": TokenType.NVARCHAR, - "ORDER SIBLINGS BY": TokenType.ORDER_SIBLINGS_BY, - "SAMPLE": TokenType.TABLE_SAMPLE, - "START": TokenType.BEGIN, - "TOP": TokenType.TOP, - "VARCHAR2": TokenType.VARCHAR, - } - - class Parser(parser.Parser): - WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER, TokenType.KEEP} - VALUES_FOLLOWED_BY_PAREN = False - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "CONVERT": exp.ConvertToCharset.from_arg_list, - "NVL": lambda args: build_coalesce(args, is_nvl=True), - "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), - "TO_CHAR": build_timetostr_or_tochar, - "TO_TIMESTAMP": _build_to_timestamp, - "TO_DATE": build_formatted_time(exp.StrToDate, "oracle"), - "TRUNC": lambda args: exp.DateTrunc( - unit=seq_get(args, 1) or exp.Literal.string("DD"), - this=seq_get(args, 0), - unabbreviate=False, - ), - } - - NO_PAREN_FUNCTION_PARSERS = { - **parser.Parser.NO_PAREN_FUNCTION_PARSERS, - "NEXT": lambda self: self._parse_next_value_for(), - "PRIOR": lambda self: self.expression(exp.Prior, this=self._parse_bitwise()), - "SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, sysdate=True), - } - - FUNCTION_PARSERS: t.Dict[str, t.Callable] = { - **parser.Parser.FUNCTION_PARSERS, - "JSON_ARRAY": lambda self: self._parse_json_array( - exp.JSONArray, - expressions=self._parse_csv(lambda: self._parse_format_json(self._parse_bitwise())), - ), - "JSON_ARRAYAGG": lambda self: self._parse_json_array( - exp.JSONArrayAgg, - 
this=self._parse_format_json(self._parse_bitwise()), - order=self._parse_order(), - ), - "JSON_EXISTS": lambda self: self._parse_json_exists(), - } - FUNCTION_PARSERS.pop("CONVERT") - - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, - "GLOBAL": lambda self: self._match_text_seq("TEMPORARY") - and self.expression(exp.TemporaryProperty, this="GLOBAL"), - "PRIVATE": lambda self: self._match_text_seq("TEMPORARY") - and self.expression(exp.TemporaryProperty, this="PRIVATE"), - "FORCE": lambda self: self.expression(exp.ForceProperty), - } - - QUERY_MODIFIER_PARSERS = { - **parser.Parser.QUERY_MODIFIER_PARSERS, - TokenType.ORDER_SIBLINGS_BY: lambda self: ("order", self._parse_order()), - TokenType.WITH: lambda self: ("options", [self._parse_query_restrictions()]), - } - - TYPE_LITERAL_PARSERS = { - exp.DataType.Type.DATE: lambda self, this, _: self.expression( - exp.DateStrToDate, this=this - ) - } - - # SELECT UNIQUE .. is old-style Oracle syntax for SELECT DISTINCT .. - # Reference: https://stackoverflow.com/a/336455 - DISTINCT_TOKENS = {TokenType.DISTINCT, TokenType.UNIQUE} - - QUERY_RESTRICTIONS: OPTIONS_TYPE = { - "WITH": ( - ("READ", "ONLY"), - ("CHECK", "OPTION"), - ), - } - - def _parse_json_array(self, expr_type: t.Type[E], **kwargs) -> E: - return self.expression( - expr_type, - null_handling=self._parse_on_handling("NULL", "NULL", "ABSENT"), - return_type=self._match_text_seq("RETURNING") and self._parse_type(), - strict=self._match_text_seq("STRICT"), - **kwargs, - ) - - def _parse_hint_function_call(self) -> t.Optional[exp.Expression]: - if not self._curr or not self._next or self._next.token_type != TokenType.L_PAREN: - return None - - this = self._curr.text - - self._advance(2) - args = self._parse_hint_args() - this = self.expression(exp.Anonymous, this=this, expressions=args) - self._match_r_paren(this) - return this - - def _parse_hint_args(self): - args = [] - result = self._parse_var() - - while result: - args.append(result) - result = self._parse_var() - - return args - - def _parse_query_restrictions(self) -> t.Optional[exp.Expression]: - kind = self._parse_var_from_options(self.QUERY_RESTRICTIONS, raise_unmatched=False) - - if not kind: - return None - - return self.expression( - exp.QueryOption, - this=kind, - expression=self._match(TokenType.CONSTRAINT) and self._parse_field(), - ) - - def _parse_json_exists(self) -> exp.JSONExists: - this = self._parse_format_json(self._parse_bitwise()) - self._match(TokenType.COMMA) - return self.expression( - exp.JSONExists, - this=this, - path=self.dialect.to_json_path(self._parse_bitwise()), - passing=self._match_text_seq("PASSING") - and self._parse_csv(lambda: self._parse_alias(self._parse_bitwise())), - on_condition=self._parse_on_condition(), - ) - - def _parse_into(self) -> t.Optional[exp.Into]: - # https://docs.oracle.com/en/database/oracle/oracle-database/19/lnpls/SELECT-INTO-statement.html - bulk_collect = self._match(TokenType.BULK_COLLECT_INTO) - if not bulk_collect and not self._match(TokenType.INTO): - return None - - index = self._index - - expressions = self._parse_expressions() - if len(expressions) == 1: - self._retreat(index) - self._match(TokenType.TABLE) - return self.expression( - exp.Into, this=self._parse_table(schema=True), bulk_collect=bulk_collect - ) - - return self.expression(exp.Into, bulk_collect=bulk_collect, expressions=expressions) - - def _parse_connect_with_prior(self): - return self._parse_assignment() - - class Generator(generator.Generator): - LOCKING_READS_SUPPORTED = True - JOIN_HINTS 
= False - TABLE_HINTS = False - DATA_TYPE_SPECIFIERS_ALLOWED = True - ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False - LIMIT_FETCH = "FETCH" - TABLESAMPLE_KEYWORDS = "SAMPLE" - LAST_DAY_SUPPORTS_DATE_PART = False - SUPPORTS_SELECT_INTO = True - TZ_TO_WITH_TIME_ZONE = True - SUPPORTS_WINDOW_EXCLUDE = True - QUERY_HINT_SEP = " " - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.TINYINT: "SMALLINT", - exp.DataType.Type.SMALLINT: "SMALLINT", - exp.DataType.Type.INT: "INT", - exp.DataType.Type.BIGINT: "INT", - exp.DataType.Type.DECIMAL: "NUMBER", - exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", - exp.DataType.Type.VARCHAR: "VARCHAR2", - exp.DataType.Type.NVARCHAR: "NVARCHAR2", - exp.DataType.Type.NCHAR: "NCHAR", - exp.DataType.Type.TEXT: "CLOB", - exp.DataType.Type.TIMETZ: "TIME", - exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP", - exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.BINARY: "BLOB", - exp.DataType.Type.VARBINARY: "BLOB", - exp.DataType.Type.ROWVERSION: "BLOB", - } - TYPE_MAPPING.pop(exp.DataType.Type.BLOB) - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.DateStrToDate: lambda self, e: self.func( - "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") - ), - exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.unit), - exp.Group: transforms.preprocess([transforms.unalias_group]), - exp.ILike: no_ilike_sql, - exp.LogicalOr: rename_func("MAX"), - exp.LogicalAnd: rename_func("MIN"), - exp.Mod: rename_func("MOD"), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_distinct_on, - transforms.eliminate_qualify, - ] - ), - exp.StrPosition: lambda self, e: ( - strposition_sql( - self, e, func_name="INSTR", supports_position=True, supports_occurrence=True - ) - ), - exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), - exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)), - exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "), - exp.Substring: rename_func("SUBSTR"), - exp.Table: lambda self, e: self.table_sql(e, sep=" "), - exp.TableSample: lambda self, e: self.tablesample_sql(e), - exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY", - exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)), - exp.ToChar: lambda self, e: self.function_fallback_sql(e), - exp.ToNumber: to_number_with_nls_param, - exp.Trim: _trim_sql, - exp.Unicode: lambda self, e: f"ASCII(UNISTR({self.sql(e.this)}))", - exp.UnixToTime: lambda self, - e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)", - } - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str: - if expression.args.get("sysdate"): - return "SYSDATE" - - this = expression.this - return self.func("CURRENT_TIMESTAMP", this) if this else "CURRENT_TIMESTAMP" - - def offset_sql(self, expression: exp.Offset) -> str: - return f"{super().offset_sql(expression)} ROWS" - - def add_column_sql(self, expression: exp.Expression) -> str: - return f"ADD {self.sql(expression)}" - - def queryoption_sql(self, expression: exp.QueryOption) -> str: - option = self.sql(expression, "this") - value = self.sql(expression, "expression") - value = f" CONSTRAINT {value}" if value else "" - - return f"{option}{value}" - - def coalesce_sql(self, expression: exp.Coalesce) -> str: - func_name = "NVL" if expression.args.get("is_nvl") else 
"COALESCE" - return rename_func(func_name)(self, expression) - - def into_sql(self, expression: exp.Into) -> str: - into = "INTO" if not expression.args.get("bulk_collect") else "BULK COLLECT INTO" - if expression.this: - return f"{self.seg(into)} {self.sql(expression, 'this')}" - - return f"{self.seg(into)} {self.expressions(expression)}" - - def hint_sql(self, expression: exp.Hint) -> str: - expressions = [] - - for expression in expression.expressions: - if isinstance(expression, exp.Anonymous): - formatted_args = self.format_args(*expression.expressions, sep=" ") - expressions.append(f"{self.sql(expression, 'this')}({formatted_args})") - else: - expressions.append(self.sql(expression)) - - return f" /*+ {self.expressions(sqls=expressions, sep=self.QUERY_HINT_SEP).strip()} */" - - def isascii_sql(self, expression: exp.IsAscii) -> str: - return f"NVL(REGEXP_LIKE({self.sql(expression.this)}, '^[' || CHR(1) || '-' || CHR(127) || ']*$'), TRUE)" diff --git a/altimate_packages/sqlglot/dialects/postgres.py b/altimate_packages/sqlglot/dialects/postgres.py deleted file mode 100644 index 037916162..000000000 --- a/altimate_packages/sqlglot/dialects/postgres.py +++ /dev/null @@ -1,778 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - DATE_ADD_OR_SUB, - Dialect, - JSON_EXTRACT_TYPE, - any_value_to_max_sql, - binary_from_function, - bool_xor_sql, - datestrtodate_sql, - build_formatted_time, - filter_array_using_unnest, - inline_array_sql, - json_extract_segments, - json_path_key_only_name, - max_or_greatest, - merge_without_target_sql, - min_or_least, - no_last_day_sql, - no_map_from_entries_sql, - no_paren_current_date_sql, - no_pivot_sql, - no_trycast_sql, - build_json_extract_path, - build_timestamp_trunc, - rename_func, - sha256_sql, - struct_extract_sql, - timestamptrunc_sql, - timestrtotime_sql, - trim_sql, - ts_or_ds_add_cast, - strposition_sql, - count_if_to_sum, - groupconcat_sql, - Version, -) -from sqlglot.generator import unsupported_args -from sqlglot.helper import is_int, seq_get -from sqlglot.parser import binary_range_parser -from sqlglot.tokens import TokenType - -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - - -DATE_DIFF_FACTOR = { - "MICROSECOND": " * 1000000", - "MILLISECOND": " * 1000", - "SECOND": "", - "MINUTE": " / 60", - "HOUR": " / 3600", - "DAY": " / 86400", -} - - -def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, DATE_ADD_OR_SUB], str]: - def func(self: Postgres.Generator, expression: DATE_ADD_OR_SUB) -> str: - if isinstance(expression, exp.TsOrDsAdd): - expression = ts_or_ds_add_cast(expression) - - this = self.sql(expression, "this") - unit = expression.args.get("unit") - - e = self._simplify_unless_literal(expression.expression) - if isinstance(e, exp.Literal): - e.args["is_string"] = True - elif e.is_number: - e = exp.Literal.string(e.to_py()) - else: - self.unsupported("Cannot add non literal") - - return f"{this} {kind} {self.sql(exp.Interval(this=e, unit=unit))}" - - return func - - -def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str: - unit = expression.text("unit").upper() - factor = DATE_DIFF_FACTOR.get(unit) - - end = f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" - start = f"CAST({self.sql(expression, 'expression')} AS TIMESTAMP)" - - if factor is not None: - return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)" - - age = f"AGE({end}, {start})" - - 
if unit == "WEEK": - unit = f"EXTRACT(days FROM ({end} - {start})) / 7" - elif unit == "MONTH": - unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" - elif unit == "QUARTER": - unit = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3" - elif unit == "YEAR": - unit = f"EXTRACT(year FROM {age})" - else: - unit = age - - return f"CAST({unit} AS BIGINT)" - - -def _substring_sql(self: Postgres.Generator, expression: exp.Substring) -> str: - this = self.sql(expression, "this") - start = self.sql(expression, "start") - length = self.sql(expression, "length") - - from_part = f" FROM {start}" if start else "" - for_part = f" FOR {length}" if length else "" - - return f"SUBSTRING({this}{from_part}{for_part})" - - -def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: - auto = expression.find(exp.AutoIncrementColumnConstraint) - - if auto: - expression.args["constraints"].remove(auto.parent) - kind = expression.args["kind"] - - if kind.this == exp.DataType.Type.INT: - kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL)) - elif kind.this == exp.DataType.Type.SMALLINT: - kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL)) - elif kind.this == exp.DataType.Type.BIGINT: - kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL)) - - return expression - - -def _serial_to_generated(expression: exp.Expression) -> exp.Expression: - if not isinstance(expression, exp.ColumnDef): - return expression - kind = expression.kind - if not kind: - return expression - - if kind.this == exp.DataType.Type.SERIAL: - data_type = exp.DataType(this=exp.DataType.Type.INT) - elif kind.this == exp.DataType.Type.SMALLSERIAL: - data_type = exp.DataType(this=exp.DataType.Type.SMALLINT) - elif kind.this == exp.DataType.Type.BIGSERIAL: - data_type = exp.DataType(this=exp.DataType.Type.BIGINT) - else: - data_type = None - - if data_type: - expression.args["kind"].replace(data_type) - constraints = expression.args["constraints"] - generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) - notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()) - - if notnull not in constraints: - constraints.insert(0, notnull) - if generated not in constraints: - constraints.insert(0, generated) - - return expression - - -def _build_generate_series(args: t.List) -> exp.ExplodingGenerateSeries: - # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day - # Note: postgres allows calls with just two arguments -- the "step" argument defaults to 1 - step = seq_get(args, 2) - if step is not None: - if step.is_string: - args[2] = exp.to_interval(step.this) - elif isinstance(step, exp.Interval) and not step.args.get("unit"): - args[2] = exp.to_interval(step.this.this) - - return exp.ExplodingGenerateSeries.from_arg_list(args) - - -def _build_to_timestamp(args: t.List) -> exp.UnixToTime | exp.StrToTime: - # TO_TIMESTAMP accepts either a single double argument or (text, text) - if len(args) == 1: - # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE - return exp.UnixToTime.from_arg_list(args) - - # https://www.postgresql.org/docs/current/functions-formatting.html - return build_formatted_time(exp.StrToTime, "postgres")(args) - - -def _json_extract_sql( - name: str, op: str -) -> t.Callable[[Postgres.Generator, JSON_EXTRACT_TYPE], str]: - def _generate(self: Postgres.Generator, expression: JSON_EXTRACT_TYPE) -> str: - if expression.args.get("only_json_types"): - return 
json_extract_segments(name, quoted_index=False, op=op)(self, expression) - return json_extract_segments(name)(self, expression) - - return _generate - - -def _build_regexp_replace(args: t.List, dialect: DialectType = None) -> exp.RegexpReplace: - # The signature of REGEXP_REPLACE is: - # regexp_replace(source, pattern, replacement [, start [, N ]] [, flags ]) - # - # Any one of `start`, `N` and `flags` can be column references, meaning that - # unless we can statically see that the last argument is a non-integer string - # (eg. not '0'), then it's not possible to construct the correct AST - if len(args) > 3: - last = args[-1] - if not is_int(last.name): - if not last.type or last.is_type(exp.DataType.Type.UNKNOWN, exp.DataType.Type.NULL): - from sqlglot.optimizer.annotate_types import annotate_types - - last = annotate_types(last, dialect=dialect) - - if last.is_type(*exp.DataType.TEXT_TYPES): - regexp_replace = exp.RegexpReplace.from_arg_list(args[:-1]) - regexp_replace.set("modifiers", last) - return regexp_replace - - return exp.RegexpReplace.from_arg_list(args) - - -def _unix_to_time_sql(self: Postgres.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = expression.this - - if scale in (None, exp.UnixToTime.SECONDS): - return self.func("TO_TIMESTAMP", timestamp, self.format_time(expression)) - - return self.func( - "TO_TIMESTAMP", - exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), - self.format_time(expression), - ) - - -def _build_levenshtein_less_equal(args: t.List) -> exp.Levenshtein: - # Postgres has two signatures for levenshtein_less_equal function, but in both cases - # max_dist is the last argument - # levenshtein_less_equal(source, target, ins_cost, del_cost, sub_cost, max_d) - # levenshtein_less_equal(source, target, max_d) - max_dist = args.pop() - - return exp.Levenshtein( - this=seq_get(args, 0), - expression=seq_get(args, 1), - ins_cost=seq_get(args, 2), - del_cost=seq_get(args, 3), - sub_cost=seq_get(args, 4), - max_dist=max_dist, - ) - - -def _levenshtein_sql(self: Postgres.Generator, expression: exp.Levenshtein) -> str: - name = "LEVENSHTEIN_LESS_EQUAL" if expression.args.get("max_dist") else "LEVENSHTEIN" - - return rename_func(name)(self, expression) - - -def _versioned_anyvalue_sql(self: Postgres.Generator, expression: exp.AnyValue) -> str: - # https://www.postgresql.org/docs/16/functions-aggregate.html - # https://www.postgresql.org/about/featurematrix/ - if self.dialect.version < Version("16.0"): - return any_value_to_max_sql(self, expression) - - return rename_func("ANY_VALUE")(self, expression) - - -class Postgres(Dialect): - INDEX_OFFSET = 1 - TYPED_DIVISION = True - CONCAT_COALESCE = True - NULL_ORDERING = "nulls_are_large" - TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" - TABLESAMPLE_SIZE_IS_PERCENT = True - - TIME_MAPPING = { - "AM": "%p", - "PM": "%p", - "d": "%u", # 1-based day of week - "D": "%u", # 1-based day of week - "dd": "%d", # day of month - "DD": "%d", # day of month - "ddd": "%j", # zero padded day of year - "DDD": "%j", # zero padded day of year - "FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres - "FMDDD": "%-j", # day of year - "FMHH12": "%-I", # 9 - "FMHH24": "%-H", # 9 - "FMMI": "%-M", # Minute - "FMMM": "%-m", # 1 - "FMSS": "%-S", # Second - "HH12": "%I", # 09 - "HH24": "%H", # 09 - "mi": "%M", # zero padded minute - "MI": "%M", # zero padded minute - "mm": "%m", # 01 - "MM": "%m", # 01 - "OF": "%z", # utc offset - "ss": "%S", # zero padded second - "SS": "%S", 
# zero padded second - "TMDay": "%A", # TM is locale dependent - "TMDy": "%a", - "TMMon": "%b", # Sep - "TMMonth": "%B", # September - "TZ": "%Z", # uppercase timezone name - "US": "%f", # zero padded microsecond - "ww": "%U", # 1-based week of year - "WW": "%U", # 1-based week of year - "yy": "%y", # 15 - "YY": "%y", # 15 - "yyyy": "%Y", # 2015 - "YYYY": "%Y", # 2015 - } - - class Tokenizer(tokens.Tokenizer): - BIT_STRINGS = [("b'", "'"), ("B'", "'")] - HEX_STRINGS = [("x'", "'"), ("X'", "'")] - BYTE_STRINGS = [("e'", "'"), ("E'", "'")] - HEREDOC_STRINGS = ["$"] - - HEREDOC_TAG_IS_IDENTIFIER = True - HEREDOC_STRING_ALTERNATIVE = TokenType.PARAMETER - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "~": TokenType.RLIKE, - "@@": TokenType.DAT, - "@>": TokenType.AT_GT, - "<@": TokenType.LT_AT, - "|/": TokenType.PIPE_SLASH, - "||/": TokenType.DPIPE_SLASH, - "BEGIN": TokenType.COMMAND, - "BEGIN TRANSACTION": TokenType.BEGIN, - "BIGSERIAL": TokenType.BIGSERIAL, - "CONSTRAINT TRIGGER": TokenType.COMMAND, - "CSTRING": TokenType.PSEUDO_TYPE, - "DECLARE": TokenType.COMMAND, - "DO": TokenType.COMMAND, - "EXEC": TokenType.COMMAND, - "HSTORE": TokenType.HSTORE, - "INT8": TokenType.BIGINT, - "MONEY": TokenType.MONEY, - "NAME": TokenType.NAME, - "OID": TokenType.OBJECT_IDENTIFIER, - "ONLY": TokenType.ONLY, - "OPERATOR": TokenType.OPERATOR, - "REFRESH": TokenType.COMMAND, - "REINDEX": TokenType.COMMAND, - "RESET": TokenType.COMMAND, - "REVOKE": TokenType.COMMAND, - "SERIAL": TokenType.SERIAL, - "SMALLSERIAL": TokenType.SMALLSERIAL, - "TEMP": TokenType.TEMPORARY, - "REGCLASS": TokenType.OBJECT_IDENTIFIER, - "REGCOLLATION": TokenType.OBJECT_IDENTIFIER, - "REGCONFIG": TokenType.OBJECT_IDENTIFIER, - "REGDICTIONARY": TokenType.OBJECT_IDENTIFIER, - "REGNAMESPACE": TokenType.OBJECT_IDENTIFIER, - "REGOPER": TokenType.OBJECT_IDENTIFIER, - "REGOPERATOR": TokenType.OBJECT_IDENTIFIER, - "REGPROC": TokenType.OBJECT_IDENTIFIER, - "REGPROCEDURE": TokenType.OBJECT_IDENTIFIER, - "REGROLE": TokenType.OBJECT_IDENTIFIER, - "REGTYPE": TokenType.OBJECT_IDENTIFIER, - "FLOAT": TokenType.DOUBLE, - } - KEYWORDS.pop("/*+") - KEYWORDS.pop("DIV") - - SINGLE_TOKENS = { - **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.HEREDOC_STRING, - } - - VAR_SINGLE_TOKENS = {"$"} - - class Parser(parser.Parser): - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, - "SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()), - } - PROPERTY_PARSERS.pop("INPUT") - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "ASCII": exp.Unicode.from_arg_list, - "DATE_TRUNC": build_timestamp_trunc, - "DIV": lambda args: exp.cast( - binary_from_function(exp.IntDiv)(args), exp.DataType.Type.DECIMAL - ), - "GENERATE_SERIES": _build_generate_series, - "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract), - "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar), - "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), encoding=seq_get(args, 1)), - "MAKE_TIME": exp.TimeFromParts.from_arg_list, - "MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list, - "NOW": exp.CurrentTimestamp.from_arg_list, - "REGEXP_REPLACE": _build_regexp_replace, - "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"), - "TO_DATE": build_formatted_time(exp.StrToDate, "postgres"), - "TO_TIMESTAMP": _build_to_timestamp, - "UNNEST": exp.Explode.from_arg_list, - "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), - "SHA384": lambda args: exp.SHA2(this=seq_get(args, 0), 
length=exp.Literal.number(384)), - "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)), - "LEVENSHTEIN_LESS_EQUAL": _build_levenshtein_less_equal, - "JSON_OBJECT_AGG": lambda args: exp.JSONObjectAgg(expressions=args), - "JSONB_OBJECT_AGG": exp.JSONBObjectAgg.from_arg_list, - } - - NO_PAREN_FUNCTIONS = { - **parser.Parser.NO_PAREN_FUNCTIONS, - TokenType.CURRENT_SCHEMA: exp.CurrentSchema, - } - - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, - "DATE_PART": lambda self: self._parse_date_part(), - "JSONB_EXISTS": lambda self: self._parse_jsonb_exists(), - } - - BITWISE = { - **parser.Parser.BITWISE, - TokenType.HASH: exp.BitwiseXor, - } - - EXPONENT = { - TokenType.CARET: exp.Pow, - } - - RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, - TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), - TokenType.DAT: lambda self, this: self.expression( - exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this] - ), - TokenType.OPERATOR: lambda self, this: self._parse_operator(this), - } - - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, - TokenType.END: lambda self: self._parse_commit_or_rollback(), - } - - JSON_ARROWS_REQUIRE_JSON_TYPE = True - - COLUMN_OPERATORS = { - **parser.Parser.COLUMN_OPERATORS, - TokenType.ARROW: lambda self, this, path: build_json_extract_path( - exp.JSONExtract, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE - )([this, path]), - TokenType.DARROW: lambda self, this, path: build_json_extract_path( - exp.JSONExtractScalar, arrow_req_json_type=self.JSON_ARROWS_REQUIRE_JSON_TYPE - )([this, path]), - } - - def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - while True: - if not self._match(TokenType.L_PAREN): - break - - op = "" - while self._curr and not self._match(TokenType.R_PAREN): - op += self._curr.text - self._advance() - - this = self.expression( - exp.Operator, - comments=self._prev_comments, - this=this, - operator=op, - expression=self._parse_bitwise(), - ) - - if not self._match(TokenType.OPERATOR): - break - - return this - - def _parse_date_part(self) -> exp.Expression: - part = self._parse_type() - self._match(TokenType.COMMA) - value = self._parse_bitwise() - - if part and isinstance(part, (exp.Column, exp.Literal)): - part = exp.var(part.name) - - return self.expression(exp.Extract, this=part, expression=value) - - def _parse_unique_key(self) -> t.Optional[exp.Expression]: - return None - - def _parse_jsonb_exists(self) -> exp.JSONBExists: - return self.expression( - exp.JSONBExists, - this=self._parse_bitwise(), - path=self._match(TokenType.COMMA) - and self.dialect.to_json_path(self._parse_bitwise()), - ) - - def _parse_generated_as_identity( - self, - ) -> ( - exp.GeneratedAsIdentityColumnConstraint - | exp.ComputedColumnConstraint - | exp.GeneratedAsRowColumnConstraint - ): - this = super()._parse_generated_as_identity() - - if self._match_text_seq("STORED"): - this = self.expression(exp.ComputedColumnConstraint, this=this.expression) - - return this - - def _parse_user_defined_type( - self, identifier: exp.Identifier - ) -> t.Optional[exp.Expression]: - udt_type: exp.Identifier | exp.Dot = identifier - - while self._match(TokenType.DOT): - part = self._parse_id_var() - if part: - udt_type = exp.Dot(this=udt_type, expression=part) - - return exp.DataType.build(udt_type, udt=True) - - class Generator(generator.Generator): - SINGLE_STRING_INTERVAL = True - RENAME_TABLE_WITH_DB = False - LOCKING_READS_SUPPORTED = True - JOIN_HINTS = False - 
TABLE_HINTS = False - QUERY_HINTS = False - NVL2_SUPPORTED = False - PARAMETER_TOKEN = "$" - TABLESAMPLE_SIZE_IS_ROWS = False - TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" - SUPPORTS_SELECT_INTO = True - JSON_TYPE_REQUIRED_FOR_EXTRACTION = True - SUPPORTS_UNLOGGED_TABLES = True - LIKE_PROPERTY_INSIDE_SCHEMA = True - MULTI_ARG_DISTINCT = False - CAN_IMPLEMENT_ARRAY_ANY = True - SUPPORTS_WINDOW_EXCLUDE = True - COPY_HAS_INTO_KEYWORD = False - ARRAY_CONCAT_IS_VAR_LEN = False - SUPPORTS_MEDIAN = False - ARRAY_SIZE_DIM_REQUIRED = True - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.TINYINT: "SMALLINT", - exp.DataType.Type.FLOAT: "REAL", - exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", - exp.DataType.Type.BINARY: "BYTEA", - exp.DataType.Type.VARBINARY: "BYTEA", - exp.DataType.Type.ROWVERSION: "BYTEA", - exp.DataType.Type.DATETIME: "TIMESTAMP", - exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP", - exp.DataType.Type.BLOB: "BYTEA", - } - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.AnyValue: _versioned_anyvalue_sql, - exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"), - exp.ArrayFilter: filter_array_using_unnest, - exp.BitwiseXor: lambda self, e: self.binary(e, "#"), - exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), - exp.CurrentDate: no_paren_current_date_sql, - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.CurrentUser: lambda *_: "CURRENT_USER", - exp.DateAdd: _date_add_sql("+"), - exp.DateDiff: _date_diff_sql, - exp.DateStrToDate: datestrtodate_sql, - exp.DateSub: _date_add_sql("-"), - exp.Explode: rename_func("UNNEST"), - exp.ExplodingGenerateSeries: rename_func("GENERATE_SERIES"), - exp.GroupConcat: lambda self, e: groupconcat_sql( - self, e, func_name="STRING_AGG", within_group=False - ), - exp.IntDiv: rename_func("DIV"), - exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"), - exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"), - exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), - exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), - exp.JSONBContains: lambda self, e: self.binary(e, "?"), - exp.ParseJSON: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.JSON)), - exp.JSONPathKey: json_path_key_only_name, - exp.JSONPathRoot: lambda *_: "", - exp.JSONPathSubscript: lambda self, e: self.json_path_part(e.this), - exp.LastDay: no_last_day_sql, - exp.LogicalOr: rename_func("BOOL_OR"), - exp.LogicalAnd: rename_func("BOOL_AND"), - exp.Max: max_or_greatest, - exp.MapFromEntries: no_map_from_entries_sql, - exp.Min: min_or_least, - exp.Merge: merge_without_target_sql, - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", - exp.PercentileCont: transforms.preprocess( - [transforms.add_within_group_for_percentiles] - ), - exp.PercentileDisc: transforms.preprocess( - [transforms.add_within_group_for_percentiles] - ), - exp.Pivot: no_pivot_sql, - exp.Rand: rename_func("RANDOM"), - exp.RegexpLike: lambda self, e: self.binary(e, "~"), - exp.RegexpILike: lambda self, e: self.binary(e, "~*"), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_semi_and_anti_joins, - transforms.eliminate_qualify, - ] - ), - exp.SHA2: sha256_sql, - exp.StrPosition: lambda self, e: strposition_sql(self, e, func_name="POSITION"), - exp.StrToDate: lambda self, e: self.func("TO_DATE", e.this, self.format_time(e)), - 
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), - exp.StructExtract: struct_extract_sql, - exp.Substring: _substring_sql, - exp.TimeFromParts: rename_func("MAKE_TIME"), - exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"), - exp.TimestampTrunc: timestamptrunc_sql(zone=True), - exp.TimeStrToTime: timestrtotime_sql, - exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)), - exp.ToChar: lambda self, e: self.function_fallback_sql(e), - exp.Trim: trim_sql, - exp.TryCast: no_trycast_sql, - exp.TsOrDsAdd: _date_add_sql("+"), - exp.TsOrDsDiff: _date_diff_sql, - exp.UnixToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this), - exp.Uuid: lambda *_: "GEN_RANDOM_UUID()", - exp.TimeToUnix: lambda self, e: self.func( - "DATE_PART", exp.Literal.string("epoch"), e.this - ), - exp.VariancePop: rename_func("VAR_POP"), - exp.Variance: rename_func("VAR_SAMP"), - exp.Xor: bool_xor_sql, - exp.Unicode: rename_func("ASCII"), - exp.UnixToTime: _unix_to_time_sql, - exp.Levenshtein: _levenshtein_sql, - exp.JSONObjectAgg: rename_func("JSON_OBJECT_AGG"), - exp.JSONBObjectAgg: rename_func("JSONB_OBJECT_AGG"), - exp.CountIf: count_if_to_sum, - } - - TRANSFORMS.pop(exp.CommentColumnConstraint) - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - def schemacommentproperty_sql(self, expression: exp.SchemaCommentProperty) -> str: - self.unsupported("Table comments are not supported in the CREATE statement") - return "" - - def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str: - self.unsupported("Column comments are not supported in the CREATE statement") - return "" - - def unnest_sql(self, expression: exp.Unnest) -> str: - if len(expression.expressions) == 1: - arg = expression.expressions[0] - if isinstance(arg, exp.GenerateDateArray): - generate_series: exp.Expression = exp.GenerateSeries(**arg.args) - if isinstance(expression.parent, (exp.From, exp.Join)): - generate_series = ( - exp.select("value::date") - .from_(exp.Table(this=generate_series).as_("_t", table=["value"])) - .subquery(expression.args.get("alias") or "_unnested_generate_series") - ) - return self.sql(generate_series) - - from sqlglot.optimizer.annotate_types import annotate_types - - this = annotate_types(arg, dialect=self.dialect) - if this.is_type("array"): - while isinstance(this, exp.Cast): - this = this.this - - arg_as_json = self.sql(exp.cast(this, exp.DataType.Type.JSON)) - alias = self.sql(expression, "alias") - alias = f" AS {alias}" if alias else "" - - if expression.args.get("offset"): - self.unsupported("Unsupported JSON_ARRAY_ELEMENTS with offset") - - return f"JSON_ARRAY_ELEMENTS({arg_as_json}){alias}" - - return super().unnest_sql(expression) - - def bracket_sql(self, expression: exp.Bracket) -> str: - """Forms like ARRAY[1, 2, 3][3] aren't allowed; we need to wrap the ARRAY.""" - if isinstance(expression.this, exp.Array): - expression.set("this", exp.paren(expression.this, copy=False)) - - return super().bracket_sql(expression) - - def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: - this = self.sql(expression, "this") - expressions = [f"{self.sql(e)} @@ {this}" for e in expression.expressions] - sql = " OR ".join(expressions) - return f"({sql})" if len(expressions) > 1 else sql - - def alterset_sql(self, 
expression: exp.AlterSet) -> str: - exprs = self.expressions(expression, flat=True) - exprs = f"({exprs})" if exprs else "" - - access_method = self.sql(expression, "access_method") - access_method = f"ACCESS METHOD {access_method}" if access_method else "" - tablespace = self.sql(expression, "tablespace") - tablespace = f"TABLESPACE {tablespace}" if tablespace else "" - option = self.sql(expression, "option") - - return f"SET {exprs}{access_method}{tablespace}{option}" - - def datatype_sql(self, expression: exp.DataType) -> str: - if expression.is_type(exp.DataType.Type.ARRAY): - if expression.expressions: - values = self.expressions(expression, key="values", flat=True) - return f"{self.expressions(expression, flat=True)}[{values}]" - return "ARRAY" - - if ( - expression.is_type(exp.DataType.Type.DOUBLE, exp.DataType.Type.FLOAT) - and expression.expressions - ): - # Postgres doesn't support precision for REAL and DOUBLE PRECISION types - return f"FLOAT({self.expressions(expression, flat=True)})" - - return super().datatype_sql(expression) - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - this = expression.this - - # Postgres casts DIV() to decimal for transpilation but when roundtripping it's superfluous - if isinstance(this, exp.IntDiv) and expression.to == exp.DataType.build("decimal"): - return self.sql(this) - - return super().cast_sql(expression, safe_prefix=safe_prefix) - - def array_sql(self, expression: exp.Array) -> str: - exprs = expression.expressions - func_name = self.normalize_func("ARRAY") - - if isinstance(seq_get(exprs, 0), exp.Select): - return f"{func_name}({self.sql(exprs[0])})" - - return f"{func_name}{inline_array_sql(self, expression)}" - - def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: - return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')}) STORED" - - def isascii_sql(self, expression: exp.IsAscii) -> str: - return f"({self.sql(expression.this)} ~ '^[[:ascii:]]*$')" - - @unsupported_args("this") - def currentschema_sql(self, expression: exp.CurrentSchema) -> str: - return "CURRENT_SCHEMA" - - def interval_sql(self, expression: exp.Interval) -> str: - unit = expression.text("unit").lower() - - if unit.startswith("quarter") and isinstance(expression.this, exp.Literal): - expression.this.replace(exp.Literal.number(int(expression.this.to_py()) * 3)) - expression.args["unit"].replace(exp.var("MONTH")) - - return super().interval_sql(expression) diff --git a/altimate_packages/sqlglot/dialects/presto.py b/altimate_packages/sqlglot/dialects/presto.py deleted file mode 100644 index 543bf2d6c..000000000 --- a/altimate_packages/sqlglot/dialects/presto.py +++ /dev/null @@ -1,788 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - NormalizationStrategy, - binary_from_function, - bool_xor_sql, - date_trunc_to_time, - datestrtodate_sql, - encode_decode_sql, - build_formatted_time, - if_sql, - left_to_substring_sql, - no_ilike_sql, - no_pivot_sql, - no_timestamp_sql, - regexp_extract_sql, - rename_func, - right_to_substring_sql, - sha256_sql, - strposition_sql, - struct_extract_sql, - timestamptrunc_sql, - timestrtotime_sql, - ts_or_ds_add_cast, - unit_to_str, - sequence_sql, - build_regexp_extract, - explode_to_unnest_sql, -) -from sqlglot.dialects.hive import Hive -from sqlglot.dialects.mysql import MySQL -from sqlglot.helper import apply_index_offset, 
seq_get -from sqlglot.optimizer.scope import find_all_in_scope -from sqlglot.tokens import TokenType -from sqlglot.transforms import unqualify_columns -from sqlglot.generator import unsupported_args - -DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TimestampAdd, exp.DateSub] - - -def _initcap_sql(self: Presto.Generator, expression: exp.Initcap) -> str: - regex = r"(\w)(\w*)" - return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" - - -def _no_sort_array(self: Presto.Generator, expression: exp.SortArray) -> str: - if expression.args.get("asc") == exp.false(): - comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END" - else: - comparator = None - return self.func("ARRAY_SORT", expression.this, comparator) - - -def _schema_sql(self: Presto.Generator, expression: exp.Schema) -> str: - if isinstance(expression.parent, exp.PartitionedByProperty): - # Any columns in the ARRAY[] string literals should not be quoted - expression.transform(lambda n: n.name if isinstance(n, exp.Identifier) else n, copy=False) - - partition_exprs = [ - self.sql(c) if isinstance(c, (exp.Func, exp.Property)) else self.sql(c, "this") - for c in expression.expressions - ] - return self.sql(exp.Array(expressions=[exp.Literal.string(c) for c in partition_exprs])) - - if expression.parent: - for schema in expression.parent.find_all(exp.Schema): - if schema is expression: - continue - - column_defs = schema.find_all(exp.ColumnDef) - if column_defs and isinstance(schema.parent, exp.Property): - expression.expressions.extend(column_defs) - - return self.schema_sql(expression) - - -def _quantile_sql(self: Presto.Generator, expression: exp.Quantile) -> str: - self.unsupported("Presto does not support exact quantiles") - return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile")) - - -def _str_to_time_sql( - self: Presto.Generator, expression: exp.StrToDate | exp.StrToTime | exp.TsOrDsToDate -) -> str: - return self.func("DATE_PARSE", expression.this, self.format_time(expression)) - - -def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str: - time_format = self.format_time(expression) - if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT): - return self.sql(exp.cast(_str_to_time_sql(self, expression), exp.DataType.Type.DATE)) - return self.sql( - exp.cast(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), exp.DataType.Type.DATE) - ) - - -def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str: - expression = ts_or_ds_add_cast(expression) - unit = unit_to_str(expression) - return self.func("DATE_ADD", unit, expression.expression, expression.this) - - -def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str: - this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMP) - expr = exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP) - unit = unit_to_str(expression) - return self.func("DATE_DIFF", unit, expr, this) - - -def _build_approx_percentile(args: t.List) -> exp.Expression: - if len(args) == 4: - return exp.ApproxQuantile( - this=seq_get(args, 0), - weight=seq_get(args, 1), - quantile=seq_get(args, 2), - accuracy=seq_get(args, 3), - ) - if len(args) == 3: - return exp.ApproxQuantile( - this=seq_get(args, 0), quantile=seq_get(args, 1), accuracy=seq_get(args, 2) - ) - return exp.ApproxQuantile.from_arg_list(args) - - -def _build_from_unixtime(args: t.List) -> exp.Expression: - if len(args) == 3: - return exp.UnixToTime( 
- this=seq_get(args, 0), - hours=seq_get(args, 1), - minutes=seq_get(args, 2), - ) - if len(args) == 2: - return exp.UnixToTime(this=seq_get(args, 0), zone=seq_get(args, 1)) - - return exp.UnixToTime.from_arg_list(args) - - -def _first_last_sql(self: Presto.Generator, expression: exp.Func) -> str: - """ - Trino doesn't support FIRST / LAST as functions, but they're valid in the context - of MATCH_RECOGNIZE, so we need to preserve them in that case. In all other cases - they're converted into an ARBITRARY call. - - Reference: https://trino.io/docs/current/sql/match-recognize.html#logical-navigation-functions - """ - if isinstance(expression.find_ancestor(exp.MatchRecognize, exp.Select), exp.MatchRecognize): - return self.function_fallback_sql(expression) - - return rename_func("ARBITRARY")(self, expression) - - -def _unix_to_time_sql(self: Presto.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = self.sql(expression, "this") - if scale in (None, exp.UnixToTime.SECONDS): - return rename_func("FROM_UNIXTIME")(self, expression) - - return f"FROM_UNIXTIME(CAST({timestamp} AS DOUBLE) / POW(10, {scale}))" - - -def _to_int(self: Presto.Generator, expression: exp.Expression) -> exp.Expression: - if not expression.type: - from sqlglot.optimizer.annotate_types import annotate_types - - annotate_types(expression, dialect=self.dialect) - if expression.type and expression.type.this not in exp.DataType.INTEGER_TYPES: - return exp.cast(expression, to=exp.DataType.Type.BIGINT) - return expression - - -def _build_to_char(args: t.List) -> exp.TimeToStr: - fmt = seq_get(args, 1) - if isinstance(fmt, exp.Literal): - # We uppercase this to match Teradata's format mapping keys - fmt.set("this", fmt.this.upper()) - - # We use "teradata" on purpose here, because the time formats are different in Presto. 
- # See https://prestodb.io/docs/current/functions/teradata.html?highlight=to_char#to_char - return build_formatted_time(exp.TimeToStr, "teradata")(args) - - -def _date_delta_sql( - name: str, negate_interval: bool = False -) -> t.Callable[[Presto.Generator, DATE_ADD_OR_SUB], str]: - def _delta_sql(self: Presto.Generator, expression: DATE_ADD_OR_SUB) -> str: - interval = _to_int(self, expression.expression) - return self.func( - name, - unit_to_str(expression), - interval * (-1) if negate_interval else interval, - expression.this, - ) - - return _delta_sql - - -def _explode_to_unnest_sql(self: Presto.Generator, expression: exp.Lateral) -> str: - explode = expression.this - if isinstance(explode, exp.Explode): - exploded_type = explode.this.type - alias = expression.args.get("alias") - - # This attempts a best-effort transpilation of LATERAL VIEW EXPLODE on a struct array - if ( - isinstance(alias, exp.TableAlias) - and isinstance(exploded_type, exp.DataType) - and exploded_type.is_type(exp.DataType.Type.ARRAY) - and exploded_type.expressions - and exploded_type.expressions[0].is_type(exp.DataType.Type.STRUCT) - ): - # When unnesting a ROW in Presto, it produces N columns, so we need to fix the alias - alias.set("columns", [c.this.copy() for c in exploded_type.expressions[0].expressions]) - elif isinstance(explode, exp.Inline): - explode.replace(exp.Explode(this=explode.this.copy())) - - return explode_to_unnest_sql(self, expression) - - -def amend_exploded_column_table(expression: exp.Expression) -> exp.Expression: - # We check for expression.type because the columns can be amended only if types were inferred - if isinstance(expression, exp.Select) and expression.type: - for lateral in expression.args.get("laterals") or []: - alias = lateral.args.get("alias") - if ( - not isinstance(lateral.this, exp.Explode) - or not isinstance(alias, exp.TableAlias) - or len(alias.columns) != 1 - ): - continue - - new_table = alias.this - old_table = alias.columns[0].name.lower() - - # When transpiling a LATERAL VIEW EXPLODE Spark query, the exploded fields may be qualified - # with the struct column, resulting in invalid Presto references that need to be amended - for column in find_all_in_scope(expression, exp.Column): - if column.db.lower() == old_table: - column.set("table", column.args["db"].pop()) - elif column.table.lower() == old_table: - column.set("table", new_table.copy()) - elif column.name.lower() == old_table and isinstance(column.parent, exp.Dot): - column.parent.replace(exp.column(column.parent.expression, table=new_table)) - - return expression - - -class Presto(Dialect): - INDEX_OFFSET = 1 - NULL_ORDERING = "nulls_are_last" - TIME_FORMAT = MySQL.TIME_FORMAT - STRICT_STRING_CONCAT = True - SUPPORTS_SEMI_ANTI_JOIN = False - TYPED_DIVISION = True - TABLESAMPLE_SIZE_IS_PERCENT = True - LOG_BASE_FIRST: t.Optional[bool] = None - SUPPORTS_VALUES_DEFAULT = False - - TIME_MAPPING = MySQL.TIME_MAPPING - - # https://github.com/trinodb/trino/issues/17 - # https://github.com/trinodb/trino/issues/12289 - # https://github.com/prestodb/presto/issues/2863 - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE - - # The result of certain math functions in Presto/Trino is of type - # equal to the input type e.g: FLOOR(5.5/2) -> DECIMAL, FLOOR(5/2) -> BIGINT - ANNOTATORS = { - **Dialect.ANNOTATORS, - exp.Floor: lambda self, e: self._annotate_by_args(e, "this"), - exp.Ceil: lambda self, e: self._annotate_by_args(e, "this"), - exp.Mod: lambda self, e: self._annotate_by_args(e, "this", 
"expression"), - exp.Round: lambda self, e: self._annotate_by_args(e, "this"), - exp.Sign: lambda self, e: self._annotate_by_args(e, "this"), - exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), - exp.Rand: lambda self, e: self._annotate_by_args(e, "this") - if e.this - else self._set_type(e, exp.DataType.Type.DOUBLE), - } - - SUPPORTED_SETTINGS = { - *Dialect.SUPPORTED_SETTINGS, - "variant_extract_is_json_extract", - } - - class Tokenizer(tokens.Tokenizer): - HEX_STRINGS = [("x'", "'"), ("X'", "'")] - UNICODE_STRINGS = [ - (prefix + q, q) - for q in t.cast(t.List[str], tokens.Tokenizer.QUOTES) - for prefix in ("U&", "u&") - ] - - NESTED_COMMENTS = False - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "DEALLOCATE PREPARE": TokenType.COMMAND, - "DESCRIBE INPUT": TokenType.COMMAND, - "DESCRIBE OUTPUT": TokenType.COMMAND, - "RESET SESSION": TokenType.COMMAND, - "START": TokenType.BEGIN, - "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, - "ROW": TokenType.STRUCT, - "IPADDRESS": TokenType.IPADDRESS, - "IPPREFIX": TokenType.IPPREFIX, - "TDIGEST": TokenType.TDIGEST, - "HYPERLOGLOG": TokenType.HLLSKETCH, - } - KEYWORDS.pop("/*+") - KEYWORDS.pop("QUALIFY") - - class Parser(parser.Parser): - VALUES_FOLLOWED_BY_PAREN = False - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "ARBITRARY": exp.AnyValue.from_arg_list, - "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, - "APPROX_PERCENTILE": _build_approx_percentile, - "BITWISE_AND": binary_from_function(exp.BitwiseAnd), - "BITWISE_NOT": lambda args: exp.BitwiseNot(this=seq_get(args, 0)), - "BITWISE_OR": binary_from_function(exp.BitwiseOr), - "BITWISE_XOR": binary_from_function(exp.BitwiseXor), - "CARDINALITY": exp.ArraySize.from_arg_list, - "CONTAINS": exp.ArrayContains.from_arg_list, - "DATE_ADD": lambda args: exp.DateAdd( - this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) - ), - "DATE_DIFF": lambda args: exp.DateDiff( - this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0) - ), - "DATE_FORMAT": build_formatted_time(exp.TimeToStr, "presto"), - "DATE_PARSE": build_formatted_time(exp.StrToTime, "presto"), - "DATE_TRUNC": date_trunc_to_time, - "DAY_OF_WEEK": exp.DayOfWeekIso.from_arg_list, - "DOW": exp.DayOfWeekIso.from_arg_list, - "DOY": exp.DayOfYear.from_arg_list, - "ELEMENT_AT": lambda args: exp.Bracket( - this=seq_get(args, 0), expressions=[seq_get(args, 1)], offset=1, safe=True - ), - "FROM_HEX": exp.Unhex.from_arg_list, - "FROM_UNIXTIME": _build_from_unixtime, - "FROM_UTF8": lambda args: exp.Decode( - this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") - ), - "JSON_FORMAT": lambda args: exp.JSONFormat( - this=seq_get(args, 0), options=seq_get(args, 1), is_json=True - ), - "LEVENSHTEIN_DISTANCE": exp.Levenshtein.from_arg_list, - "NOW": exp.CurrentTimestamp.from_arg_list, - "REGEXP_EXTRACT": build_regexp_extract(exp.RegexpExtract), - "REGEXP_EXTRACT_ALL": build_regexp_extract(exp.RegexpExtractAll), - "REGEXP_REPLACE": lambda args: exp.RegexpReplace( - this=seq_get(args, 0), - expression=seq_get(args, 1), - replacement=seq_get(args, 2) or exp.Literal.string(""), - ), - "ROW": exp.Struct.from_arg_list, - "SEQUENCE": exp.GenerateSeries.from_arg_list, - "SET_AGG": exp.ArrayUniqueAgg.from_arg_list, - "SPLIT_TO_MAP": exp.StrToMap.from_arg_list, - "STRPOS": lambda args: exp.StrPosition( - this=seq_get(args, 0), substr=seq_get(args, 1), occurrence=seq_get(args, 2) - ), - "TO_CHAR": _build_to_char, - "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, - "TO_UTF8": lambda args: 
exp.Encode( - this=seq_get(args, 0), charset=exp.Literal.string("utf-8") - ), - "MD5": exp.MD5Digest.from_arg_list, - "SHA256": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(256)), - "SHA512": lambda args: exp.SHA2(this=seq_get(args, 0), length=exp.Literal.number(512)), - } - - FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy() - FUNCTION_PARSERS.pop("TRIM") - - class Generator(generator.Generator): - INTERVAL_ALLOWS_PLURAL_FORM = False - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - IS_BOOL_ALLOWED = False - TZ_TO_WITH_TIME_ZONE = True - NVL2_SUPPORTED = False - STRUCT_DELIMITER = ("(", ")") - LIMIT_ONLY_LITERALS = True - SUPPORTS_SINGLE_ARG_CONCAT = False - LIKE_PROPERTY_INSIDE_SCHEMA = True - MULTI_ARG_DISTINCT = False - SUPPORTS_TO_NUMBER = False - HEX_FUNC = "TO_HEX" - PARSE_JSON_NAME = "JSON_PARSE" - PAD_FILL_PATTERN_IS_REQUIRED = True - EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False - SUPPORTS_MEDIAN = False - ARRAY_SIZE_NAME = "CARDINALITY" - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.LocationProperty: exp.Properties.Location.UNSUPPORTED, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.BINARY: "VARBINARY", - exp.DataType.Type.BIT: "BOOLEAN", - exp.DataType.Type.DATETIME: "TIMESTAMP", - exp.DataType.Type.DATETIME64: "TIMESTAMP", - exp.DataType.Type.FLOAT: "REAL", - exp.DataType.Type.HLLSKETCH: "HYPERLOGLOG", - exp.DataType.Type.INT: "INTEGER", - exp.DataType.Type.STRUCT: "ROW", - exp.DataType.Type.TEXT: "VARCHAR", - exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP", - exp.DataType.Type.TIMETZ: "TIME", - } - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.AnyValue: rename_func("ARBITRARY"), - exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"), - exp.ArgMax: rename_func("MAX_BY"), - exp.ArgMin: rename_func("MIN_BY"), - exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", - exp.ArrayAny: rename_func("ANY_MATCH"), - exp.ArrayConcat: rename_func("CONCAT"), - exp.ArrayContains: rename_func("CONTAINS"), - exp.ArrayToString: rename_func("ARRAY_JOIN"), - exp.ArrayUniqueAgg: rename_func("SET_AGG"), - exp.AtTimeZone: rename_func("AT_TIMEZONE"), - exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression), - exp.BitwiseLeftShift: lambda self, e: self.func( - "BITWISE_ARITHMETIC_SHIFT_LEFT", e.this, e.expression - ), - exp.BitwiseNot: lambda self, e: self.func("BITWISE_NOT", e.this), - exp.BitwiseOr: lambda self, e: self.func("BITWISE_OR", e.this, e.expression), - exp.BitwiseRightShift: lambda self, e: self.func( - "BITWISE_ARITHMETIC_SHIFT_RIGHT", e.this, e.expression - ), - exp.BitwiseXor: lambda self, e: self.func("BITWISE_XOR", e.this, e.expression), - exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]), - exp.CurrentTime: lambda *_: "CURRENT_TIME", - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.CurrentUser: lambda *_: "CURRENT_USER", - exp.DateAdd: _date_delta_sql("DATE_ADD"), - exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", unit_to_str(e), e.expression, e.this - ), - exp.DateStrToDate: datestrtodate_sql, - exp.DateToDi: lambda self, - e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.DATEINT_FORMAT}) AS INT)", - exp.DateSub: _date_delta_sql("DATE_ADD", negate_interval=True), - exp.DayOfWeek: lambda self, e: f"(({self.func('DAY_OF_WEEK', e.this)} % 7) + 1)", - exp.DayOfWeekIso: 
rename_func("DAY_OF_WEEK"), - exp.Decode: lambda self, e: encode_decode_sql(self, e, "FROM_UTF8"), - exp.DiToDate: lambda self, - e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.DATEINT_FORMAT}) AS DATE)", - exp.Encode: lambda self, e: encode_decode_sql(self, e, "TO_UTF8"), - exp.FileFormatProperty: lambda self, - e: f"format={self.sql(exp.Literal.string(e.name))}", - exp.First: _first_last_sql, - exp.FromTimeZone: lambda self, - e: f"WITH_TIMEZONE({self.sql(e, 'this')}, {self.sql(e, 'zone')}) AT TIME ZONE 'UTC'", - exp.GenerateSeries: sequence_sql, - exp.GenerateDateArray: sequence_sql, - exp.Group: transforms.preprocess([transforms.unalias_group]), - exp.If: if_sql(), - exp.ILike: no_ilike_sql, - exp.Initcap: _initcap_sql, - exp.Last: _first_last_sql, - exp.LastDay: lambda self, e: self.func("LAST_DAY_OF_MONTH", e.this), - exp.Lateral: _explode_to_unnest_sql, - exp.Left: left_to_substring_sql, - exp.Levenshtein: unsupported_args("ins_cost", "del_cost", "sub_cost", "max_dist")( - rename_func("LEVENSHTEIN_DISTANCE") - ), - exp.LogicalAnd: rename_func("BOOL_AND"), - exp.LogicalOr: rename_func("BOOL_OR"), - exp.Pivot: no_pivot_sql, - exp.Quantile: _quantile_sql, - exp.RegexpExtract: regexp_extract_sql, - exp.RegexpExtractAll: regexp_extract_sql, - exp.Right: right_to_substring_sql, - exp.Schema: _schema_sql, - exp.SchemaCommentProperty: lambda self, e: self.naked_property(e), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_window_clause, - transforms.eliminate_qualify, - transforms.eliminate_distinct_on, - transforms.explode_projection_to_unnest(1), - transforms.eliminate_semi_and_anti_joins, - amend_exploded_column_table, - ] - ), - exp.SortArray: _no_sort_array, - exp.StrPosition: lambda self, e: strposition_sql(self, e, supports_occurrence=True), - exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", - exp.StrToMap: rename_func("SPLIT_TO_MAP"), - exp.StrToTime: _str_to_time_sql, - exp.StructExtract: struct_extract_sql, - exp.Table: transforms.preprocess([transforms.unnest_generate_series]), - exp.Timestamp: no_timestamp_sql, - exp.TimestampAdd: _date_delta_sql("DATE_ADD"), - exp.TimestampTrunc: timestamptrunc_sql(), - exp.TimeStrToDate: timestrtotime_sql, - exp.TimeStrToTime: timestrtotime_sql, - exp.TimeStrToUnix: lambda self, e: self.func( - "TO_UNIXTIME", self.func("DATE_PARSE", e.this, Presto.TIME_FORMAT) - ), - exp.TimeToStr: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), - exp.TimeToUnix: rename_func("TO_UNIXTIME"), - exp.ToChar: lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e)), - exp.TryCast: transforms.preprocess([transforms.epoch_cast_to_ts]), - exp.TsOrDiToDi: lambda self, - e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", - exp.TsOrDsAdd: _ts_or_ds_add_sql, - exp.TsOrDsDiff: _ts_or_ds_diff_sql, - exp.TsOrDsToDate: _ts_or_ds_to_date_sql, - exp.Unhex: rename_func("FROM_HEX"), - exp.UnixToStr: lambda self, - e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", - exp.UnixToTime: _unix_to_time_sql, - exp.UnixToTimeStr: lambda self, - e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", - exp.VariancePop: rename_func("VAR_POP"), - exp.With: transforms.preprocess([transforms.add_recursive_cte_column_names]), - exp.WithinGroup: transforms.preprocess( - [transforms.remove_within_group_for_percentiles] - ), - exp.Xor: bool_xor_sql, - exp.MD5Digest: rename_func("MD5"), - exp.SHA: rename_func("SHA1"), - 
exp.SHA2: sha256_sql, - } - - RESERVED_KEYWORDS = { - "alter", - "and", - "as", - "between", - "by", - "case", - "cast", - "constraint", - "create", - "cross", - "current_time", - "current_timestamp", - "deallocate", - "delete", - "describe", - "distinct", - "drop", - "else", - "end", - "escape", - "except", - "execute", - "exists", - "extract", - "false", - "for", - "from", - "full", - "group", - "having", - "in", - "inner", - "insert", - "intersect", - "into", - "is", - "join", - "left", - "like", - "natural", - "not", - "null", - "on", - "or", - "order", - "outer", - "prepare", - "right", - "select", - "table", - "then", - "true", - "union", - "using", - "values", - "when", - "where", - "with", - } - - def jsonformat_sql(self, expression: exp.JSONFormat) -> str: - this = expression.this - is_json = expression.args.get("is_json") - - if this and not (is_json or this.type): - from sqlglot.optimizer.annotate_types import annotate_types - - this = annotate_types(this, dialect=self.dialect) - - if not (is_json or this.is_type(exp.DataType.Type.JSON)): - this.replace(exp.cast(this, exp.DataType.Type.JSON)) - - return self.function_fallback_sql(expression) - - def md5_sql(self, expression: exp.MD5) -> str: - this = expression.this - - if not this.type: - from sqlglot.optimizer.annotate_types import annotate_types - - this = annotate_types(this, dialect=self.dialect) - - if this.is_type(*exp.DataType.TEXT_TYPES): - this = exp.Encode(this=this, charset=exp.Literal.string("utf-8")) - - return self.func("LOWER", self.func("TO_HEX", self.func("MD5", self.sql(this)))) - - def strtounix_sql(self, expression: exp.StrToUnix) -> str: - # Since `TO_UNIXTIME` requires a `TIMESTAMP`, we need to parse the argument into one. - # To do this, we first try to `DATE_PARSE` it, but since this can fail when there's a - # timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback, - # which seems to be using the same time mapping as Hive, as per: - # https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html - this = expression.this - value_as_text = exp.cast(this, exp.DataType.Type.TEXT) - value_as_timestamp = ( - exp.cast(this, exp.DataType.Type.TIMESTAMP) if this.is_string else this - ) - - parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression)) - - formatted_value = self.func( - "DATE_FORMAT", value_as_timestamp, self.format_time(expression) - ) - parse_with_tz = self.func( - "PARSE_DATETIME", - formatted_value, - self.format_time(expression, Hive.INVERSE_TIME_MAPPING, Hive.INVERSE_TIME_TRIE), - ) - coalesced = self.func("COALESCE", self.func("TRY", parse_without_tz), parse_with_tz) - return self.func("TO_UNIXTIME", coalesced) - - def bracket_sql(self, expression: exp.Bracket) -> str: - if expression.args.get("safe"): - return self.func( - "ELEMENT_AT", - expression.this, - seq_get( - apply_index_offset( - expression.this, - expression.expressions, - 1 - expression.args.get("offset", 0), - dialect=self.dialect, - ), - 0, - ), - ) - return super().bracket_sql(expression) - - def struct_sql(self, expression: exp.Struct) -> str: - from sqlglot.optimizer.annotate_types import annotate_types - - expression = annotate_types(expression, dialect=self.dialect) - values: t.List[str] = [] - schema: t.List[str] = [] - unknown_type = False - - for e in expression.expressions: - if isinstance(e, exp.PropertyEQ): - if e.type and e.type.is_type(exp.DataType.Type.UNKNOWN): - unknown_type = True - else: - schema.append(f"{self.sql(e, 'this')} 
{self.sql(e.type)}") - values.append(self.sql(e, "expression")) - else: - values.append(self.sql(e)) - - size = len(expression.expressions) - - if not size or len(schema) != size: - if unknown_type: - self.unsupported( - "Cannot convert untyped key-value definitions (try annotate_types)." - ) - return self.func("ROW", *values) - return f"CAST(ROW({', '.join(values)}) AS ROW({', '.join(schema)}))" - - def interval_sql(self, expression: exp.Interval) -> str: - if expression.this and expression.text("unit").upper().startswith("WEEK"): - return f"({expression.this.name} * INTERVAL '7' DAY)" - return super().interval_sql(expression) - - def transaction_sql(self, expression: exp.Transaction) -> str: - modes = expression.args.get("modes") - modes = f" {', '.join(modes)}" if modes else "" - return f"START TRANSACTION{modes}" - - def offset_limit_modifiers( - self, expression: exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit] - ) -> t.List[str]: - return [ - self.sql(expression, "offset"), - self.sql(limit), - ] - - def create_sql(self, expression: exp.Create) -> str: - """ - Presto doesn't support CREATE VIEW with expressions (ex: `CREATE VIEW x (cola)` then `(cola)` is the expression), - so we need to remove them - """ - kind = expression.args["kind"] - schema = expression.this - if kind == "VIEW" and schema.expressions: - expression.this.set("expressions", None) - return super().create_sql(expression) - - def delete_sql(self, expression: exp.Delete) -> str: - """ - Presto only supports DELETE FROM for a single table without an alias, so we need - to remove the unnecessary parts. If the original DELETE statement contains more - than one table to be deleted, we can't safely map it 1-1 to a Presto statement. - """ - tables = expression.args.get("tables") or [expression.this] - if len(tables) > 1: - return super().delete_sql(expression) - - table = tables[0] - expression.set("this", table) - expression.set("tables", None) - - if isinstance(table, exp.Table): - table_alias = table.args.get("alias") - if table_alias: - table_alias.pop() - expression = t.cast(exp.Delete, expression.transform(unqualify_columns)) - - return super().delete_sql(expression) - - def jsonextract_sql(self, expression: exp.JSONExtract) -> str: - is_json_extract = self.dialect.settings.get("variant_extract_is_json_extract", True) - - # Generate JSON_EXTRACT unless the user has configured that a Snowflake / Databricks - # VARIANT extract (e.g. 
col:x.y) should map to dot notation (i.e ROW access) in Presto/Trino - if not expression.args.get("variant_extract") or is_json_extract: - return self.func( - "JSON_EXTRACT", expression.this, expression.expression, *expression.expressions - ) - - this = self.sql(expression, "this") - - # Convert the JSONPath extraction `JSON_EXTRACT(col, '$.x.y) to a ROW access col.x.y - segments = [] - for path_key in expression.expression.expressions[1:]: - if not isinstance(path_key, exp.JSONPathKey): - # Cannot transpile subscripts, wildcards etc to dot notation - self.unsupported( - f"Cannot transpile JSONPath segment '{path_key}' to ROW access" - ) - continue - key = path_key.this - if not exp.SAFE_IDENTIFIER_RE.match(key): - key = f'"{key}"' - segments.append(f".{key}") - - expr = "".join(segments) - - return f"{this}{expr}" - - def groupconcat_sql(self, expression: exp.GroupConcat) -> str: - return self.func( - "ARRAY_JOIN", - self.func("ARRAY_AGG", expression.this), - expression.args.get("separator"), - ) diff --git a/altimate_packages/sqlglot/dialects/prql.py b/altimate_packages/sqlglot/dialects/prql.py deleted file mode 100644 index 022abbca4..000000000 --- a/altimate_packages/sqlglot/dialects/prql.py +++ /dev/null @@ -1,203 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, parser, tokens -from sqlglot.dialects.dialect import Dialect -from sqlglot.helper import seq_get -from sqlglot.tokens import TokenType - - -def _select_all(table: exp.Expression) -> t.Optional[exp.Select]: - return exp.select("*").from_(table, copy=False) if table else None - - -class PRQL(Dialect): - DPIPE_IS_STRING_CONCAT = False - - class Tokenizer(tokens.Tokenizer): - IDENTIFIERS = ["`"] - QUOTES = ["'", '"'] - - SINGLE_TOKENS = { - **tokens.Tokenizer.SINGLE_TOKENS, - "=": TokenType.ALIAS, - "'": TokenType.QUOTE, - '"': TokenType.QUOTE, - "`": TokenType.IDENTIFIER, - "#": TokenType.COMMENT, - } - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - } - - class Parser(parser.Parser): - CONJUNCTION = { - **parser.Parser.CONJUNCTION, - TokenType.DAMP: exp.And, - } - - DISJUNCTION = { - **parser.Parser.DISJUNCTION, - TokenType.DPIPE: exp.Or, - } - - TRANSFORM_PARSERS = { - "DERIVE": lambda self, query: self._parse_selection(query), - "SELECT": lambda self, query: self._parse_selection(query, append=False), - "TAKE": lambda self, query: self._parse_take(query), - "FILTER": lambda self, query: query.where(self._parse_disjunction()), - "APPEND": lambda self, query: query.union( - _select_all(self._parse_table()), distinct=False, copy=False - ), - "REMOVE": lambda self, query: query.except_( - _select_all(self._parse_table()), distinct=False, copy=False - ), - "INTERSECT": lambda self, query: query.intersect( - _select_all(self._parse_table()), distinct=False, copy=False - ), - "SORT": lambda self, query: self._parse_order_by(query), - "AGGREGATE": lambda self, query: self._parse_selection( - query, parse_method=self._parse_aggregate, append=False - ), - } - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "AVERAGE": exp.Avg.from_arg_list, - "SUM": lambda args: exp.func("COALESCE", exp.Sum(this=seq_get(args, 0)), 0), - } - - def _parse_equality(self) -> t.Optional[exp.Expression]: - eq = self._parse_tokens(self._parse_comparison, self.EQUALITY) - if not isinstance(eq, (exp.EQ, exp.NEQ)): - return eq - - # https://prql-lang.org/book/reference/spec/null.html - if isinstance(eq.expression, exp.Null): - is_exp = exp.Is(this=eq.this, expression=eq.expression) - return is_exp if isinstance(eq, 
exp.EQ) else exp.Not(this=is_exp) - if isinstance(eq.this, exp.Null): - is_exp = exp.Is(this=eq.expression, expression=eq.this) - return is_exp if isinstance(eq, exp.EQ) else exp.Not(this=is_exp) - return eq - - def _parse_statement(self) -> t.Optional[exp.Expression]: - expression = self._parse_expression() - expression = expression if expression else self._parse_query() - return expression - - def _parse_query(self) -> t.Optional[exp.Query]: - from_ = self._parse_from() - - if not from_: - return None - - query = exp.select("*").from_(from_, copy=False) - - while self._match_texts(self.TRANSFORM_PARSERS): - query = self.TRANSFORM_PARSERS[self._prev.text.upper()](self, query) - - return query - - def _parse_selection( - self, - query: exp.Query, - parse_method: t.Optional[t.Callable] = None, - append: bool = True, - ) -> exp.Query: - parse_method = parse_method if parse_method else self._parse_expression - if self._match(TokenType.L_BRACE): - selects = self._parse_csv(parse_method) - - if not self._match(TokenType.R_BRACE, expression=query): - self.raise_error("Expecting }") - else: - expression = parse_method() - selects = [expression] if expression else [] - - projections = { - select.alias_or_name: select.this if isinstance(select, exp.Alias) else select - for select in query.selects - } - - selects = [ - select.transform( - lambda s: (projections[s.name].copy() if s.name in projections else s) - if isinstance(s, exp.Column) - else s, - copy=False, - ) - for select in selects - ] - - return query.select(*selects, append=append, copy=False) - - def _parse_take(self, query: exp.Query) -> t.Optional[exp.Query]: - num = self._parse_number() # TODO: TAKE for ranges a..b - return query.limit(num) if num else None - - def _parse_ordered( - self, parse_method: t.Optional[t.Callable] = None - ) -> t.Optional[exp.Ordered]: - asc = self._match(TokenType.PLUS) - desc = self._match(TokenType.DASH) or (asc and False) - term = term = super()._parse_ordered(parse_method=parse_method) - if term and desc: - term.set("desc", True) - term.set("nulls_first", False) - return term - - def _parse_order_by(self, query: exp.Select) -> t.Optional[exp.Query]: - l_brace = self._match(TokenType.L_BRACE) - expressions = self._parse_csv(self._parse_ordered) - if l_brace and not self._match(TokenType.R_BRACE): - self.raise_error("Expecting }") - return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False) - - def _parse_aggregate(self) -> t.Optional[exp.Expression]: - alias = None - if self._next and self._next.token_type == TokenType.ALIAS: - alias = self._parse_id_var(any_token=True) - self._match(TokenType.ALIAS) - - name = self._curr and self._curr.text.upper() - func_builder = self.FUNCTIONS.get(name) - if func_builder: - self._advance() - args = self._parse_column() - func = func_builder([args]) - else: - self.raise_error(f"Unsupported aggregation function {name}") - if alias: - return self.expression(exp.Alias, this=func, alias=alias) - return func - - def _parse_expression(self) -> t.Optional[exp.Expression]: - if self._next and self._next.token_type == TokenType.ALIAS: - alias = self._parse_id_var(True) - self._match(TokenType.ALIAS) - return self.expression(exp.Alias, this=self._parse_assignment(), alias=alias) - return self._parse_assignment() - - def _parse_table( - self, - schema: bool = False, - joins: bool = False, - alias_tokens: t.Optional[t.Collection[TokenType]] = None, - parse_bracket: bool = False, - is_db_reference: bool = False, - parse_partition: bool = False, - ) -> 
t.Optional[exp.Expression]: - return self._parse_table_parts() - - def _parse_from( - self, joins: bool = False, skip_from_token: bool = False - ) -> t.Optional[exp.From]: - if not skip_from_token and not self._match(TokenType.FROM): - return None - - return self.expression( - exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins) - ) diff --git a/altimate_packages/sqlglot/dialects/redshift.py b/altimate_packages/sqlglot/dialects/redshift.py deleted file mode 100644 index 27998fc89..000000000 --- a/altimate_packages/sqlglot/dialects/redshift.py +++ /dev/null @@ -1,448 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, transforms -from sqlglot.dialects.dialect import ( - NormalizationStrategy, - concat_to_dpipe_sql, - concat_ws_to_dpipe_sql, - date_delta_sql, - generatedasidentitycolumnconstraint_sql, - json_extract_segments, - no_tablesample_sql, - rename_func, - map_date_part, -) -from sqlglot.dialects.postgres import Postgres -from sqlglot.helper import seq_get -from sqlglot.tokens import TokenType -from sqlglot.parser import build_convert_timezone - -if t.TYPE_CHECKING: - from sqlglot._typing import E - - -def _build_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: - def _builder(args: t.List) -> E: - expr = expr_type( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=map_date_part(seq_get(args, 0)), - ) - if expr_type is exp.TsOrDsAdd: - expr.set("return_type", exp.DataType.build("TIMESTAMP")) - - return expr - - return _builder - - -class Redshift(Postgres): - # https://docs.aws.amazon.com/redshift/latest/dg/r_names.html - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE - - SUPPORTS_USER_DEFINED_TYPES = False - INDEX_OFFSET = 0 - COPY_PARAMS_ARE_CSV = False - HEX_LOWERCASE = True - HAS_DISTINCT_ARRAY_CONSTRUCTORS = True - - # ref: https://docs.aws.amazon.com/redshift/latest/dg/r_FORMAT_strings.html - TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" - TIME_MAPPING = {**Postgres.TIME_MAPPING, "MON": "%b", "HH24": "%H", "HH": "%I"} - - class Parser(Postgres.Parser): - FUNCTIONS = { - **Postgres.Parser.FUNCTIONS, - "ADD_MONTHS": lambda args: exp.TsOrDsAdd( - this=seq_get(args, 0), - expression=seq_get(args, 1), - unit=exp.var("month"), - return_type=exp.DataType.build("TIMESTAMP"), - ), - "CONVERT_TIMEZONE": lambda args: build_convert_timezone(args, "UTC"), - "DATEADD": _build_date_delta(exp.TsOrDsAdd), - "DATE_ADD": _build_date_delta(exp.TsOrDsAdd), - "DATEDIFF": _build_date_delta(exp.TsOrDsDiff), - "DATE_DIFF": _build_date_delta(exp.TsOrDsDiff), - "GETDATE": exp.CurrentTimestamp.from_arg_list, - "LISTAGG": exp.GroupConcat.from_arg_list, - "SPLIT_TO_ARRAY": lambda args: exp.StringToArray( - this=seq_get(args, 0), expression=seq_get(args, 1) or exp.Literal.string(",") - ), - "STRTOL": exp.FromBase.from_arg_list, - } - - NO_PAREN_FUNCTION_PARSERS = { - **Postgres.Parser.NO_PAREN_FUNCTION_PARSERS, - "APPROXIMATE": lambda self: self._parse_approximate_count(), - "SYSDATE": lambda self: self.expression(exp.CurrentTimestamp, sysdate=True), - } - - SUPPORTS_IMPLICIT_UNNEST = True - - def _parse_table( - self, - schema: bool = False, - joins: bool = False, - alias_tokens: t.Optional[t.Collection[TokenType]] = None, - parse_bracket: bool = False, - is_db_reference: bool = False, - parse_partition: bool = False, - ) -> t.Optional[exp.Expression]: - # Redshift supports UNPIVOTing SUPER objects, e.g. 
`UNPIVOT foo.obj[0] AS val AT attr` - unpivot = self._match(TokenType.UNPIVOT) - table = super()._parse_table( - schema=schema, - joins=joins, - alias_tokens=alias_tokens, - parse_bracket=parse_bracket, - is_db_reference=is_db_reference, - ) - - return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table - - def _parse_convert( - self, strict: bool, safe: t.Optional[bool] = None - ) -> t.Optional[exp.Expression]: - to = self._parse_types() - self._match(TokenType.COMMA) - this = self._parse_bitwise() - return self.expression(exp.TryCast, this=this, to=to, safe=safe) - - def _parse_approximate_count(self) -> t.Optional[exp.ApproxDistinct]: - index = self._index - 1 - func = self._parse_function() - - if isinstance(func, exp.Count) and isinstance(func.this, exp.Distinct): - return self.expression(exp.ApproxDistinct, this=seq_get(func.this.expressions, 0)) - self._retreat(index) - return None - - class Tokenizer(Postgres.Tokenizer): - BIT_STRINGS = [] - HEX_STRINGS = [] - STRING_ESCAPES = ["\\", "'"] - - KEYWORDS = { - **Postgres.Tokenizer.KEYWORDS, - "(+)": TokenType.JOIN_MARKER, - "HLLSKETCH": TokenType.HLLSKETCH, - "MINUS": TokenType.EXCEPT, - "SUPER": TokenType.SUPER, - "TOP": TokenType.TOP, - "UNLOAD": TokenType.COMMAND, - "VARBYTE": TokenType.VARBINARY, - "BINARY VARYING": TokenType.VARBINARY, - } - KEYWORDS.pop("VALUES") - - # Redshift allows # to appear as a table identifier prefix - SINGLE_TOKENS = Postgres.Tokenizer.SINGLE_TOKENS.copy() - SINGLE_TOKENS.pop("#") - - class Generator(Postgres.Generator): - LOCKING_READS_SUPPORTED = False - QUERY_HINTS = False - VALUES_AS_TABLE = False - TZ_TO_WITH_TIME_ZONE = True - NVL2_SUPPORTED = True - LAST_DAY_SUPPORTS_DATE_PART = False - CAN_IMPLEMENT_ARRAY_ANY = False - MULTI_ARG_DISTINCT = True - COPY_PARAMS_ARE_WRAPPED = False - HEX_FUNC = "TO_HEX" - PARSE_JSON_NAME = "JSON_PARSE" - ARRAY_CONCAT_IS_VAR_LEN = False - SUPPORTS_CONVERT_TIMEZONE = True - EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False - SUPPORTS_MEDIAN = True - ALTER_SET_TYPE = "TYPE" - - # Redshift doesn't have `WITH` as part of their with_properties so we remove it - WITH_PROPERTIES_PREFIX = " " - - TYPE_MAPPING = { - **Postgres.Generator.TYPE_MAPPING, - exp.DataType.Type.BINARY: "VARBYTE", - exp.DataType.Type.BLOB: "VARBYTE", - exp.DataType.Type.INT: "INTEGER", - exp.DataType.Type.TIMETZ: "TIME", - exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - exp.DataType.Type.VARBINARY: "VARBYTE", - exp.DataType.Type.ROWVERSION: "VARBYTE", - } - - TRANSFORMS = { - **Postgres.Generator.TRANSFORMS, - exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CONCAT"), - exp.Concat: concat_to_dpipe_sql, - exp.ConcatWs: concat_ws_to_dpipe_sql, - exp.ApproxDistinct: lambda self, - e: f"APPROXIMATE COUNT(DISTINCT {self.sql(e, 'this')})", - exp.CurrentTimestamp: lambda self, e: ( - "SYSDATE" if e.args.get("sysdate") else "GETDATE()" - ), - exp.DateAdd: date_delta_sql("DATEADD"), - exp.DateDiff: date_delta_sql("DATEDIFF"), - exp.DistKeyProperty: lambda self, e: self.func("DISTKEY", e.this), - exp.DistStyleProperty: lambda self, e: self.naked_property(e), - exp.Explode: lambda self, e: self.explode_sql(e), - exp.FromBase: rename_func("STRTOL"), - exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, - exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH_TEXT"), - exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"), - exp.GroupConcat: rename_func("LISTAGG"), - exp.Hex: lambda self, e: self.func("UPPER", self.func("TO_HEX", 
self.sql(e, "this"))), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_window_clause, - transforms.eliminate_distinct_on, - transforms.eliminate_semi_and_anti_joins, - transforms.unqualify_unnest, - transforms.unnest_generate_date_array_using_recursive_cte, - ] - ), - exp.SortKeyProperty: lambda self, - e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})", - exp.StartsWith: lambda self, - e: f"{self.sql(e.this)} LIKE {self.sql(e.expression)} || '%'", - exp.StringToArray: rename_func("SPLIT_TO_ARRAY"), - exp.TableSample: no_tablesample_sql, - exp.TsOrDsAdd: date_delta_sql("DATEADD"), - exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.UnixToTime: lambda self, - e: f"(TIMESTAMP 'epoch' + {self.sql(e.this)} * INTERVAL '1 SECOND')", - } - - # Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots - TRANSFORMS.pop(exp.Pivot) - - # Postgres doesn't support JSON_PARSE, but Redshift does - TRANSFORMS.pop(exp.ParseJSON) - - # Redshift supports these functions - TRANSFORMS.pop(exp.AnyValue) - TRANSFORMS.pop(exp.LastDay) - TRANSFORMS.pop(exp.SHA2) - - RESERVED_KEYWORDS = { - "aes128", - "aes256", - "all", - "allowoverwrite", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "authorization", - "az64", - "backup", - "between", - "binary", - "blanksasnull", - "both", - "bytedict", - "bzip2", - "case", - "cast", - "check", - "collate", - "column", - "constraint", - "create", - "credentials", - "cross", - "current_date", - "current_time", - "current_timestamp", - "current_user", - "current_user_id", - "default", - "deferrable", - "deflate", - "defrag", - "delta", - "delta32k", - "desc", - "disable", - "distinct", - "do", - "else", - "emptyasnull", - "enable", - "encode", - "encrypt ", - "encryption", - "end", - "except", - "explicit", - "false", - "for", - "foreign", - "freeze", - "from", - "full", - "globaldict256", - "globaldict64k", - "grant", - "group", - "gzip", - "having", - "identity", - "ignore", - "ilike", - "in", - "initially", - "inner", - "intersect", - "interval", - "into", - "is", - "isnull", - "join", - "leading", - "left", - "like", - "limit", - "localtime", - "localtimestamp", - "lun", - "luns", - "lzo", - "lzop", - "minus", - "mostly16", - "mostly32", - "mostly8", - "natural", - "new", - "not", - "notnull", - "null", - "nulls", - "off", - "offline", - "offset", - "oid", - "old", - "on", - "only", - "open", - "or", - "order", - "outer", - "overlaps", - "parallel", - "partition", - "percent", - "permissions", - "pivot", - "placing", - "primary", - "raw", - "readratio", - "recover", - "references", - "rejectlog", - "resort", - "respect", - "restore", - "right", - "select", - "session_user", - "similar", - "snapshot", - "some", - "sysdate", - "system", - "table", - "tag", - "tdes", - "text255", - "text32k", - "then", - "timestamp", - "to", - "top", - "trailing", - "true", - "truncatecolumns", - "type", - "union", - "unique", - "unnest", - "unpivot", - "user", - "using", - "verbose", - "wallet", - "when", - "where", - "with", - "without", - } - - def unnest_sql(self, expression: exp.Unnest) -> str: - args = expression.expressions - num_args = len(args) - - if num_args != 1: - self.unsupported(f"Unsupported number of arguments in UNNEST: {num_args}") - return "" - - if isinstance(expression.find_ancestor(exp.From, exp.Join, exp.Select), exp.Select): - self.unsupported("Unsupported UNNEST when not used in FROM/JOIN clauses") - return "" - - arg = self.sql(seq_get(args, 0)) - - alias = 
self.expressions(expression.args.get("alias"), key="columns", flat=True) - return f"{arg} AS {alias}" if alias else arg - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - if expression.is_type(exp.DataType.Type.JSON): - # Redshift doesn't support a JSON type, so casting to it is treated as a noop - return self.sql(expression, "this") - - return super().cast_sql(expression, safe_prefix=safe_prefix) - - def datatype_sql(self, expression: exp.DataType) -> str: - """ - Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean - VARCHAR of max length which is `VARCHAR(max)` in Redshift. Therefore if we get a `TEXT` data type - without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert - `TEXT` to `VARCHAR`. - """ - if expression.is_type("text"): - expression.set("this", exp.DataType.Type.VARCHAR) - precision = expression.args.get("expressions") - - if not precision: - expression.append("expressions", exp.var("MAX")) - - return super().datatype_sql(expression) - - def alterset_sql(self, expression: exp.AlterSet) -> str: - exprs = self.expressions(expression, flat=True) - exprs = f" TABLE PROPERTIES ({exprs})" if exprs else "" - location = self.sql(expression, "location") - location = f" LOCATION {location}" if location else "" - file_format = self.expressions(expression, key="file_format", flat=True, sep=" ") - file_format = f" FILE FORMAT {file_format}" if file_format else "" - - return f"SET{exprs}{location}{file_format}" - - def array_sql(self, expression: exp.Array) -> str: - if expression.args.get("bracket_notation"): - return super().array_sql(expression) - - return rename_func("ARRAY")(self, expression) - - def explode_sql(self, expression: exp.Explode) -> str: - self.unsupported("Unsupported EXPLODE() function") - return "" diff --git a/altimate_packages/sqlglot/dialects/risingwave.py b/altimate_packages/sqlglot/dialects/risingwave.py deleted file mode 100644 index 7a1775d38..000000000 --- a/altimate_packages/sqlglot/dialects/risingwave.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations -from sqlglot.dialects.postgres import Postgres -from sqlglot.generator import Generator -from sqlglot.tokens import TokenType -import typing as t - -from sqlglot import exp - - -class RisingWave(Postgres): - class Tokenizer(Postgres.Tokenizer): - KEYWORDS = { - **Postgres.Tokenizer.KEYWORDS, - "SINK": TokenType.SINK, - "SOURCE": TokenType.SOURCE, - } - - class Parser(Postgres.Parser): - WRAPPED_TRANSFORM_COLUMN_CONSTRAINT = False - - PROPERTY_PARSERS = { - **Postgres.Parser.PROPERTY_PARSERS, - "ENCODE": lambda self: self._parse_encode_property(), - "INCLUDE": lambda self: self._parse_include_property(), - "KEY": lambda self: self._parse_encode_property(key=True), - } - - def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]: - # There is no hint in risingwave. - # Do nothing here to avoid WITH keywords conflict in CREATE SINK statement. 
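For reference, a minimal sketch of the behavior the override above is protecting (assuming sqlglot is importable, the dialect is registered as `risingwave`, and this `CREATE SINK ... FROM ... WITH (...)` form parses in the installed version; the sink, source, and property names are placeholders):

```python
# Because _parse_table_hints returns None, the WITH (...) clause of a CREATE SINK
# statement is consumed as sink properties rather than Postgres-style table hints.
import sqlglot

sql = "CREATE SINK my_sink FROM my_mv WITH (connector = 'kafka', topic = 'events')"
expr = sqlglot.parse_one(sql, read="risingwave")
print(expr.sql(dialect="risingwave"))  # round-trip text may differ slightly
```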
- return None - - def _parse_include_property(self) -> t.Optional[exp.Expression]: - header: t.Optional[exp.Expression] = None - coldef: t.Optional[exp.Expression] = None - - this = self._parse_var_or_string() - - if not self._match(TokenType.ALIAS): - header = self._parse_field() - if header: - coldef = self.expression(exp.ColumnDef, this=header, kind=self._parse_types()) - - self._match(TokenType.ALIAS) - alias = self._parse_id_var(tokens=self.ALIAS_TOKENS) - - return self.expression(exp.IncludeProperty, this=this, alias=alias, column_def=coldef) - - def _parse_encode_property(self, key: t.Optional[bool] = None) -> exp.EncodeProperty: - self._match_text_seq("ENCODE") - this = self._parse_var_or_string() - - if self._match(TokenType.L_PAREN, advance=False): - properties = self.expression( - exp.Properties, expressions=self._parse_wrapped_properties() - ) - else: - properties = None - - return self.expression(exp.EncodeProperty, this=this, properties=properties, key=key) - - class Generator(Postgres.Generator): - LOCKING_READS_SUPPORTED = False - - TRANSFORMS = { - **Postgres.Generator.TRANSFORMS, - exp.FileFormatProperty: lambda self, e: f"FORMAT {self.sql(e, 'this')}", - } - - PROPERTIES_LOCATION = { - **Postgres.Generator.PROPERTIES_LOCATION, - exp.FileFormatProperty: exp.Properties.Location.POST_EXPRESSION, - } - - EXPRESSION_PRECEDES_PROPERTIES_CREATABLES = {"SINK"} - - def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: - return Generator.computedcolumnconstraint_sql(self, expression) diff --git a/altimate_packages/sqlglot/dialects/snowflake.py b/altimate_packages/sqlglot/dialects/snowflake.py deleted file mode 100644 index a978d254f..000000000 --- a/altimate_packages/sqlglot/dialects/snowflake.py +++ /dev/null @@ -1,1464 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, generator, jsonpath, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - NormalizationStrategy, - build_timetostr_or_tochar, - binary_from_function, - build_default_decimal_type, - build_timestamp_from_parts, - date_delta_sql, - date_trunc_to_time, - datestrtodate_sql, - build_formatted_time, - if_sql, - inline_array_sql, - max_or_greatest, - min_or_least, - rename_func, - timestamptrunc_sql, - timestrtotime_sql, - var_map_sql, - map_date_part, - no_timestamp_sql, - strposition_sql, - timestampdiff_sql, - no_make_interval_sql, - groupconcat_sql, -) -from sqlglot.generator import unsupported_args -from sqlglot.helper import flatten, is_float, is_int, seq_get -from sqlglot.tokens import TokenType - -if t.TYPE_CHECKING: - from sqlglot._typing import E, B - - -# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html -def _build_datetime( - name: str, kind: exp.DataType.Type, safe: bool = False -) -> t.Callable[[t.List], exp.Func]: - def _builder(args: t.List) -> exp.Func: - value = seq_get(args, 0) - scale_or_fmt = seq_get(args, 1) - - int_value = value is not None and is_int(value.name) - int_scale_or_fmt = scale_or_fmt is not None and scale_or_fmt.is_int - - if isinstance(value, exp.Literal) or (value and scale_or_fmt): - # Converts calls like `TO_TIME('01:02:03')` into casts - if len(args) == 1 and value.is_string and not int_value: - return ( - exp.TryCast(this=value, to=exp.DataType.build(kind)) - if safe - else exp.cast(value, kind) - ) - - # Handles `TO_TIMESTAMP(str, fmt)` and `TO_TIMESTAMP(num, scale)` as special - # cases so we can transpile them, since they're relatively common - 
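A small usage sketch of the builder described just above (assuming sqlglot is installed; DuckDB is an arbitrary target dialect and the literal values are placeholders, so the rendered output is only approximate):

```python
# _build_datetime turns common TO_TIMESTAMP call shapes into portable expressions:
# a numeric argument becomes a UNIX-epoch conversion, a string plus format becomes
# a formatted string-to-time expression.
import sqlglot

print(sqlglot.transpile("SELECT TO_TIMESTAMP(1700000000)",
                        read="snowflake", write="duckdb")[0])
print(sqlglot.transpile("SELECT TO_TIMESTAMP('2024-01-02', 'YYYY-MM-DD')",
                        read="snowflake", write="duckdb")[0])
# Roughly: TO_TIMESTAMP(1700000000) and STRPTIME('2024-01-02', '%Y-%m-%d')
```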
if kind == exp.DataType.Type.TIMESTAMP: - if not safe and (int_value or int_scale_or_fmt): - # TRY_TO_TIMESTAMP('integer') is not parsed into exp.UnixToTime as - # it's not easily transpilable - return exp.UnixToTime(this=value, scale=scale_or_fmt) - if not int_scale_or_fmt and not is_float(value.name): - expr = build_formatted_time(exp.StrToTime, "snowflake")(args) - expr.set("safe", safe) - return expr - - if kind in (exp.DataType.Type.DATE, exp.DataType.Type.TIME) and not int_value: - klass = exp.TsOrDsToDate if kind == exp.DataType.Type.DATE else exp.TsOrDsToTime - formatted_exp = build_formatted_time(klass, "snowflake")(args) - formatted_exp.set("safe", safe) - return formatted_exp - - return exp.Anonymous(this=name, expressions=args) - - return _builder - - -def _build_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]: - expression = parser.build_var_map(args) - - if isinstance(expression, exp.StarMap): - return expression - - return exp.Struct( - expressions=[ - exp.PropertyEQ(this=k, expression=v) for k, v in zip(expression.keys, expression.values) - ] - ) - - -def _build_datediff(args: t.List) -> exp.DateDiff: - return exp.DateDiff( - this=seq_get(args, 2), expression=seq_get(args, 1), unit=map_date_part(seq_get(args, 0)) - ) - - -def _build_date_time_add(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: - def _builder(args: t.List) -> E: - return expr_type( - this=seq_get(args, 2), - expression=seq_get(args, 1), - unit=map_date_part(seq_get(args, 0)), - ) - - return _builder - - -def _build_bitwise(expr_type: t.Type[B], name: str) -> t.Callable[[t.List], B | exp.Anonymous]: - def _builder(args: t.List) -> B | exp.Anonymous: - if len(args) == 3: - return exp.Anonymous(this=name, expressions=args) - - return binary_from_function(expr_type)(args) - - return _builder - - -# https://docs.snowflake.com/en/sql-reference/functions/div0 -def _build_if_from_div0(args: t.List) -> exp.If: - lhs = exp._wrap(seq_get(args, 0), exp.Binary) - rhs = exp._wrap(seq_get(args, 1), exp.Binary) - - cond = exp.EQ(this=rhs, expression=exp.Literal.number(0)).and_( - exp.Is(this=lhs, expression=exp.null()).not_() - ) - true = exp.Literal.number(0) - false = exp.Div(this=lhs, expression=rhs) - return exp.If(this=cond, true=true, false=false) - - -# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _build_if_from_zeroifnull(args: t.List) -> exp.If: - cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) - return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0)) - - -# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull -def _build_if_from_nullifzero(args: t.List) -> exp.If: - cond = exp.EQ(this=seq_get(args, 0), expression=exp.Literal.number(0)) - return exp.If(this=cond, true=exp.Null(), false=seq_get(args, 0)) - - -def _regexpilike_sql(self: Snowflake.Generator, expression: exp.RegexpILike) -> str: - flag = expression.text("flag") - - if "i" not in flag: - flag += "i" - - return self.func( - "REGEXP_LIKE", expression.this, expression.expression, exp.Literal.string(flag) - ) - - -def _build_regexp_replace(args: t.List) -> exp.RegexpReplace: - regexp_replace = exp.RegexpReplace.from_arg_list(args) - - if not regexp_replace.args.get("replacement"): - regexp_replace.set("replacement", exp.Literal.string("")) - - return regexp_replace - - -def _show_parser(*args: t.Any, **kwargs: t.Any) -> t.Callable[[Snowflake.Parser], exp.Show]: - def _parse(self: Snowflake.Parser) -> exp.Show: - return self._parse_show_snowflake(*args, **kwargs) - - 
return _parse - - -def _date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: - trunc = date_trunc_to_time(args) - trunc.set("unit", map_date_part(trunc.args["unit"])) - return trunc - - -def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: - """ - Snowflake doesn't allow columns referenced in UNPIVOT to be qualified, - so we need to unqualify them. Same goes for ANY ORDER BY . - - Example: - >>> from sqlglot import parse_one - >>> expr = parse_one("SELECT * FROM m_sales UNPIVOT(sales FOR month IN (m_sales.jan, feb, mar, april))") - >>> print(_unqualify_pivot_columns(expr).sql(dialect="snowflake")) - SELECT * FROM m_sales UNPIVOT(sales FOR month IN (jan, feb, mar, april)) - """ - if isinstance(expression, exp.Pivot): - if expression.unpivot: - expression = transforms.unqualify_columns(expression) - else: - for field in expression.fields: - field_expr = seq_get(field.expressions if field else [], 0) - - if isinstance(field_expr, exp.PivotAny): - unqualified_field_expr = transforms.unqualify_columns(field_expr) - t.cast(exp.Expression, field).set("expressions", unqualified_field_expr, 0) - - return expression - - -def _flatten_structured_types_unless_iceberg(expression: exp.Expression) -> exp.Expression: - assert isinstance(expression, exp.Create) - - def _flatten_structured_type(expression: exp.DataType) -> exp.DataType: - if expression.this in exp.DataType.NESTED_TYPES: - expression.set("expressions", None) - return expression - - props = expression.args.get("properties") - if isinstance(expression.this, exp.Schema) and not (props and props.find(exp.IcebergProperty)): - for schema_expression in expression.this.expressions: - if isinstance(schema_expression, exp.ColumnDef): - column_type = schema_expression.kind - if isinstance(column_type, exp.DataType): - column_type.transform(_flatten_structured_type, copy=False) - - return expression - - -def _unnest_generate_date_array(unnest: exp.Unnest) -> None: - generate_date_array = unnest.expressions[0] - start = generate_date_array.args.get("start") - end = generate_date_array.args.get("end") - step = generate_date_array.args.get("step") - - if not start or not end or not isinstance(step, exp.Interval) or step.name != "1": - return - - unit = step.args.get("unit") - - unnest_alias = unnest.args.get("alias") - if unnest_alias: - unnest_alias = unnest_alias.copy() - sequence_value_name = seq_get(unnest_alias.columns, 0) or "value" - else: - sequence_value_name = "value" - - # We'll add the next sequence value to the starting date and project the result - date_add = _build_date_time_add(exp.DateAdd)( - [unit, exp.cast(sequence_value_name, "int"), exp.cast(start, "date")] - ).as_(sequence_value_name) - - # We use DATEDIFF to compute the number of sequence values needed - number_sequence = Snowflake.Parser.FUNCTIONS["ARRAY_GENERATE_RANGE"]( - [exp.Literal.number(0), _build_datediff([unit, start, end]) + 1] - ) - - unnest.set("expressions", [number_sequence]) - unnest.replace(exp.select(date_add).from_(unnest.copy()).subquery(unnest_alias)) - - -def _transform_generate_date_array(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Select): - for generate_date_array in expression.find_all(exp.GenerateDateArray): - parent = generate_date_array.parent - - # If GENERATE_DATE_ARRAY is used directly as an array (e.g passed into ARRAY_LENGTH), the transformed Snowflake - # query is the following (it'll be unnested properly on the next iteration due to copy): - # SELECT 
ref(GENERATE_DATE_ARRAY(...)) -> SELECT ref((SELECT ARRAY_AGG(*) FROM UNNEST(GENERATE_DATE_ARRAY(...)))) - if not isinstance(parent, exp.Unnest): - unnest = exp.Unnest(expressions=[generate_date_array.copy()]) - generate_date_array.replace( - exp.select(exp.ArrayAgg(this=exp.Star())).from_(unnest).subquery() - ) - - if ( - isinstance(parent, exp.Unnest) - and isinstance(parent.parent, (exp.From, exp.Join)) - and len(parent.expressions) == 1 - ): - _unnest_generate_date_array(parent) - - return expression - - -def _build_regexp_extract(expr_type: t.Type[E]) -> t.Callable[[t.List], E]: - def _builder(args: t.List) -> E: - return expr_type( - this=seq_get(args, 0), - expression=seq_get(args, 1), - position=seq_get(args, 2), - occurrence=seq_get(args, 3), - parameters=seq_get(args, 4), - group=seq_get(args, 5) or exp.Literal.number(0), - ) - - return _builder - - -def _regexpextract_sql(self, expression: exp.RegexpExtract | exp.RegexpExtractAll) -> str: - # Other dialects don't support all of the following parameters, so we need to - # generate default values as necessary to ensure the transpilation is correct - group = expression.args.get("group") - - # To avoid generating all these default values, we set group to None if - # it's 0 (also default value) which doesn't trigger the following chain - if group and group.name == "0": - group = None - - parameters = expression.args.get("parameters") or (group and exp.Literal.string("c")) - occurrence = expression.args.get("occurrence") or (parameters and exp.Literal.number(1)) - position = expression.args.get("position") or (occurrence and exp.Literal.number(1)) - - return self.func( - "REGEXP_SUBSTR" if isinstance(expression, exp.RegexpExtract) else "REGEXP_EXTRACT_ALL", - expression.this, - expression.expression, - position, - occurrence, - parameters, - group, - ) - - -def _json_extract_value_array_sql( - self: Snowflake.Generator, expression: exp.JSONValueArray | exp.JSONExtractArray -) -> str: - json_extract = exp.JSONExtract(this=expression.this, expression=expression.expression) - ident = exp.to_identifier("x") - - if isinstance(expression, exp.JSONValueArray): - this: exp.Expression = exp.cast(ident, to=exp.DataType.Type.VARCHAR) - else: - this = exp.ParseJSON(this=f"TO_JSON({ident})") - - transform_lambda = exp.Lambda(expressions=[ident], this=this) - - return self.func("TRANSFORM", json_extract, transform_lambda) - - -class Snowflake(Dialect): - # https://docs.snowflake.com/en/sql-reference/identifiers-syntax - NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE - NULL_ORDERING = "nulls_are_large" - TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" - SUPPORTS_USER_DEFINED_TYPES = False - SUPPORTS_SEMI_ANTI_JOIN = False - PREFER_CTE_ALIAS_COLUMN = True - TABLESAMPLE_SIZE_IS_PERCENT = True - COPY_PARAMS_ARE_CSV = False - ARRAY_AGG_INCLUDES_NULLS = None - - TIME_MAPPING = { - "YYYY": "%Y", - "yyyy": "%Y", - "YY": "%y", - "yy": "%y", - "MMMM": "%B", - "mmmm": "%B", - "MON": "%b", - "mon": "%b", - "MM": "%m", - "mm": "%m", - "DD": "%d", - "dd": "%-d", - "DY": "%a", - "dy": "%w", - "HH24": "%H", - "hh24": "%H", - "HH12": "%I", - "hh12": "%I", - "MI": "%M", - "mi": "%M", - "SS": "%S", - "ss": "%S", - "FF6": "%f", - "ff6": "%f", - } - - DATE_PART_MAPPING = { - **Dialect.DATE_PART_MAPPING, - "ISOWEEK": "WEEKISO", - } - - def quote_identifier(self, expression: E, identify: bool = True) -> E: - # This disables quoting DUAL in SELECT ... 
FROM DUAL, because Snowflake treats an - # unquoted DUAL keyword in a special way and does not map it to a user-defined table - if ( - isinstance(expression, exp.Identifier) - and isinstance(expression.parent, exp.Table) - and expression.name.lower() == "dual" - ): - return expression # type: ignore - - return super().quote_identifier(expression, identify=identify) - - class JSONPathTokenizer(jsonpath.JSONPathTokenizer): - SINGLE_TOKENS = jsonpath.JSONPathTokenizer.SINGLE_TOKENS.copy() - SINGLE_TOKENS.pop("$") - - class Parser(parser.Parser): - IDENTIFY_PIVOT_STRINGS = True - DEFAULT_SAMPLING_METHOD = "BERNOULLI" - COLON_IS_VARIANT_EXTRACT = True - - ID_VAR_TOKENS = { - *parser.Parser.ID_VAR_TOKENS, - TokenType.MATCH_CONDITION, - } - - TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS | {TokenType.WINDOW} - TABLE_ALIAS_TOKENS.discard(TokenType.MATCH_CONDITION) - - COLON_PLACEHOLDER_TOKENS = ID_VAR_TOKENS | {TokenType.NUMBER} - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, - "ARRAY_CONSTRUCT": lambda args: exp.Array(expressions=args), - "ARRAY_CONTAINS": lambda args: exp.ArrayContains( - this=seq_get(args, 1), expression=seq_get(args, 0) - ), - "ARRAY_GENERATE_RANGE": lambda args: exp.GenerateSeries( - # ARRAY_GENERATE_RANGE has an exlusive end; we normalize it to be inclusive - start=seq_get(args, 0), - end=exp.Sub(this=seq_get(args, 1), expression=exp.Literal.number(1)), - step=seq_get(args, 2), - ), - "BITXOR": _build_bitwise(exp.BitwiseXor, "BITXOR"), - "BIT_XOR": _build_bitwise(exp.BitwiseXor, "BITXOR"), - "BITOR": _build_bitwise(exp.BitwiseOr, "BITOR"), - "BIT_OR": _build_bitwise(exp.BitwiseOr, "BITOR"), - "BITSHIFTLEFT": _build_bitwise(exp.BitwiseLeftShift, "BITSHIFTLEFT"), - "BIT_SHIFTLEFT": _build_bitwise(exp.BitwiseLeftShift, "BIT_SHIFTLEFT"), - "BITSHIFTRIGHT": _build_bitwise(exp.BitwiseRightShift, "BITSHIFTRIGHT"), - "BIT_SHIFTRIGHT": _build_bitwise(exp.BitwiseRightShift, "BIT_SHIFTRIGHT"), - "BOOLXOR": _build_bitwise(exp.Xor, "BOOLXOR"), - "DATE": _build_datetime("DATE", exp.DataType.Type.DATE), - "DATE_TRUNC": _date_trunc_to_time, - "DATEADD": _build_date_time_add(exp.DateAdd), - "DATEDIFF": _build_datediff, - "DIV0": _build_if_from_div0, - "EDITDISTANCE": lambda args: exp.Levenshtein( - this=seq_get(args, 0), expression=seq_get(args, 1), max_dist=seq_get(args, 2) - ), - "FLATTEN": exp.Explode.from_arg_list, - "GET_PATH": lambda args, dialect: exp.JSONExtract( - this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) - ), - "HEX_DECODE_BINARY": exp.Unhex.from_arg_list, - "IFF": exp.If.from_arg_list, - "LAST_DAY": lambda args: exp.LastDay( - this=seq_get(args, 0), unit=map_date_part(seq_get(args, 1)) - ), - "LEN": lambda args: exp.Length(this=seq_get(args, 0), binary=True), - "LENGTH": lambda args: exp.Length(this=seq_get(args, 0), binary=True), - "NULLIFZERO": _build_if_from_nullifzero, - "OBJECT_CONSTRUCT": _build_object_construct, - "REGEXP_EXTRACT_ALL": _build_regexp_extract(exp.RegexpExtractAll), - "REGEXP_REPLACE": _build_regexp_replace, - "REGEXP_SUBSTR": _build_regexp_extract(exp.RegexpExtract), - "REGEXP_SUBSTR_ALL": _build_regexp_extract(exp.RegexpExtractAll), - "RLIKE": exp.RegexpLike.from_arg_list, - "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), - "TABLE": lambda args: exp.TableFromRows(this=seq_get(args, 0)), - "TIMEADD": _build_date_time_add(exp.TimeAdd), - "TIMEDIFF": _build_datediff, - "TIMESTAMPADD": _build_date_time_add(exp.DateAdd), - 
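A short sketch of what the DATEADD/DATEDIFF entries in this FUNCTIONS mapping (and their TIMESTAMPADD/TIMESTAMPDIFF aliases) buy during transpilation (assuming sqlglot is installed; DuckDB is an arbitrary target and the column names are placeholders):

```python
# Snowflake DATEADD(unit, n, col) and DATEDIFF(unit, a, b) are parsed into
# exp.DateAdd / exp.DateDiff, so the unit and argument order are normalized
# before generating SQL for another dialect.
import sqlglot

print(sqlglot.transpile(
    "SELECT DATEADD(DAY, 5, created_at), DATEDIFF(DAY, a, b) FROM t",
    read="snowflake",
    write="duckdb",
)[0])
```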
"TIMESTAMPDIFF": _build_datediff, - "TIMESTAMPFROMPARTS": build_timestamp_from_parts, - "TIMESTAMP_FROM_PARTS": build_timestamp_from_parts, - "TIMESTAMPNTZFROMPARTS": build_timestamp_from_parts, - "TIMESTAMP_NTZ_FROM_PARTS": build_timestamp_from_parts, - "TRY_PARSE_JSON": lambda args: exp.ParseJSON(this=seq_get(args, 0), safe=True), - "TRY_TO_DATE": _build_datetime("TRY_TO_DATE", exp.DataType.Type.DATE, safe=True), - "TRY_TO_TIME": _build_datetime("TRY_TO_TIME", exp.DataType.Type.TIME, safe=True), - "TRY_TO_TIMESTAMP": _build_datetime( - "TRY_TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP, safe=True - ), - "TO_CHAR": build_timetostr_or_tochar, - "TO_DATE": _build_datetime("TO_DATE", exp.DataType.Type.DATE), - "TO_NUMBER": lambda args: exp.ToNumber( - this=seq_get(args, 0), - format=seq_get(args, 1), - precision=seq_get(args, 2), - scale=seq_get(args, 3), - ), - "TO_TIME": _build_datetime("TO_TIME", exp.DataType.Type.TIME), - "TO_TIMESTAMP": _build_datetime("TO_TIMESTAMP", exp.DataType.Type.TIMESTAMP), - "TO_TIMESTAMP_LTZ": _build_datetime("TO_TIMESTAMP_LTZ", exp.DataType.Type.TIMESTAMPLTZ), - "TO_TIMESTAMP_NTZ": _build_datetime("TO_TIMESTAMP_NTZ", exp.DataType.Type.TIMESTAMP), - "TO_TIMESTAMP_TZ": _build_datetime("TO_TIMESTAMP_TZ", exp.DataType.Type.TIMESTAMPTZ), - "TO_VARCHAR": exp.ToChar.from_arg_list, - "ZEROIFNULL": _build_if_from_zeroifnull, - } - - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, - "DATE_PART": lambda self: self._parse_date_part(), - "OBJECT_CONSTRUCT_KEEP_NULL": lambda self: self._parse_json_object(), - "LISTAGG": lambda self: self._parse_string_agg(), - } - FUNCTION_PARSERS.pop("TRIM") - - TIMESTAMPS = parser.Parser.TIMESTAMPS - {TokenType.TIME} - - RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, - TokenType.LIKE_ANY: parser.binary_range_parser(exp.LikeAny), - TokenType.ILIKE_ANY: parser.binary_range_parser(exp.ILikeAny), - } - - ALTER_PARSERS = { - **parser.Parser.ALTER_PARSERS, - "UNSET": lambda self: self.expression( - exp.Set, - tag=self._match_text_seq("TAG"), - expressions=self._parse_csv(self._parse_id_var), - unset=True, - ), - } - - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, - TokenType.GET: lambda self: self._parse_get(), - TokenType.PUT: lambda self: self._parse_put(), - TokenType.SHOW: lambda self: self._parse_show(), - } - - PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, - "CREDENTIALS": lambda self: self._parse_credentials_property(), - "FILE_FORMAT": lambda self: self._parse_file_format_property(), - "LOCATION": lambda self: self._parse_location_property(), - "TAG": lambda self: self._parse_tag(), - "USING": lambda self: self._match_text_seq("TEMPLATE") - and self.expression(exp.UsingTemplateProperty, this=self._parse_statement()), - } - - TYPE_CONVERTERS = { - # https://docs.snowflake.com/en/sql-reference/data-types-numeric#number - exp.DataType.Type.DECIMAL: build_default_decimal_type(precision=38, scale=0), - } - - SHOW_PARSERS = { - "DATABASES": _show_parser("DATABASES"), - "TERSE DATABASES": _show_parser("DATABASES"), - "SCHEMAS": _show_parser("SCHEMAS"), - "TERSE SCHEMAS": _show_parser("SCHEMAS"), - "OBJECTS": _show_parser("OBJECTS"), - "TERSE OBJECTS": _show_parser("OBJECTS"), - "TABLES": _show_parser("TABLES"), - "TERSE TABLES": _show_parser("TABLES"), - "VIEWS": _show_parser("VIEWS"), - "TERSE VIEWS": _show_parser("VIEWS"), - "PRIMARY KEYS": _show_parser("PRIMARY KEYS"), - "TERSE PRIMARY KEYS": _show_parser("PRIMARY KEYS"), - "IMPORTED KEYS": _show_parser("IMPORTED KEYS"), - "TERSE IMPORTED KEYS": 
_show_parser("IMPORTED KEYS"), - "UNIQUE KEYS": _show_parser("UNIQUE KEYS"), - "TERSE UNIQUE KEYS": _show_parser("UNIQUE KEYS"), - "SEQUENCES": _show_parser("SEQUENCES"), - "TERSE SEQUENCES": _show_parser("SEQUENCES"), - "STAGES": _show_parser("STAGES"), - "COLUMNS": _show_parser("COLUMNS"), - "USERS": _show_parser("USERS"), - "TERSE USERS": _show_parser("USERS"), - "FILE FORMATS": _show_parser("FILE FORMATS"), - "FUNCTIONS": _show_parser("FUNCTIONS"), - "PROCEDURES": _show_parser("PROCEDURES"), - "WAREHOUSES": _show_parser("WAREHOUSES"), - } - - CONSTRAINT_PARSERS = { - **parser.Parser.CONSTRAINT_PARSERS, - "WITH": lambda self: self._parse_with_constraint(), - "MASKING": lambda self: self._parse_with_constraint(), - "PROJECTION": lambda self: self._parse_with_constraint(), - "TAG": lambda self: self._parse_with_constraint(), - } - - STAGED_FILE_SINGLE_TOKENS = { - TokenType.DOT, - TokenType.MOD, - TokenType.SLASH, - } - - FLATTEN_COLUMNS = ["SEQ", "KEY", "PATH", "INDEX", "VALUE", "THIS"] - - SCHEMA_KINDS = {"OBJECTS", "TABLES", "VIEWS", "SEQUENCES", "UNIQUE KEYS", "IMPORTED KEYS"} - - NON_TABLE_CREATABLES = {"STORAGE INTEGRATION", "TAG", "WAREHOUSE", "STREAMLIT"} - - LAMBDAS = { - **parser.Parser.LAMBDAS, - TokenType.ARROW: lambda self, expressions: self.expression( - exp.Lambda, - this=self._replace_lambda( - self._parse_assignment(), - expressions, - ), - expressions=[e.this if isinstance(e, exp.Cast) else e for e in expressions], - ), - } - - def _parse_use(self) -> exp.Use: - if self._match_text_seq("SECONDARY", "ROLES"): - this = self._match_texts(("ALL", "NONE")) and exp.var(self._prev.text.upper()) - roles = None if this else self._parse_csv(lambda: self._parse_table(schema=False)) - return self.expression( - exp.Use, kind="SECONDARY ROLES", this=this, expressions=roles - ) - - return super()._parse_use() - - def _negate_range( - self, this: t.Optional[exp.Expression] = None - ) -> t.Optional[exp.Expression]: - if not this: - return this - - query = this.args.get("query") - if isinstance(this, exp.In) and isinstance(query, exp.Query): - # Snowflake treats `value NOT IN (subquery)` as `VALUE <> ALL (subquery)`, so - # we do this conversion here to avoid parsing it into `NOT value IN (subquery)` - # which can produce different results (most likely a SnowFlake bug). 
- # - # https://docs.snowflake.com/en/sql-reference/functions/in - # Context: https://github.com/tobymao/sqlglot/issues/3890 - return self.expression( - exp.NEQ, this=this.this, expression=exp.All(this=query.unnest()) - ) - - return self.expression(exp.Not, this=this) - - def _parse_tag(self) -> exp.Tags: - return self.expression( - exp.Tags, - expressions=self._parse_wrapped_csv(self._parse_property), - ) - - def _parse_with_constraint(self) -> t.Optional[exp.Expression]: - if self._prev.token_type != TokenType.WITH: - self._retreat(self._index - 1) - - if self._match_text_seq("MASKING", "POLICY"): - policy = self._parse_column() - return self.expression( - exp.MaskingPolicyColumnConstraint, - this=policy.to_dot() if isinstance(policy, exp.Column) else policy, - expressions=self._match(TokenType.USING) - and self._parse_wrapped_csv(self._parse_id_var), - ) - if self._match_text_seq("PROJECTION", "POLICY"): - policy = self._parse_column() - return self.expression( - exp.ProjectionPolicyColumnConstraint, - this=policy.to_dot() if isinstance(policy, exp.Column) else policy, - ) - if self._match(TokenType.TAG): - return self._parse_tag() - - return None - - def _parse_with_property(self) -> t.Optional[exp.Expression] | t.List[exp.Expression]: - if self._match(TokenType.TAG): - return self._parse_tag() - - return super()._parse_with_property() - - def _parse_create(self) -> exp.Create | exp.Command: - expression = super()._parse_create() - if isinstance(expression, exp.Create) and expression.kind in self.NON_TABLE_CREATABLES: - # Replace the Table node with the enclosed Identifier - expression.this.replace(expression.this.this) - - return expression - - # https://docs.snowflake.com/en/sql-reference/functions/date_part.html - # https://docs.snowflake.com/en/sql-reference/functions-date-time.html#label-supported-date-time-parts - def _parse_date_part(self: Snowflake.Parser) -> t.Optional[exp.Expression]: - this = self._parse_var() or self._parse_type() - - if not this: - return None - - self._match(TokenType.COMMA) - expression = self._parse_bitwise() - this = map_date_part(this) - name = this.name.upper() - - if name.startswith("EPOCH"): - if name == "EPOCH_MILLISECOND": - scale = 10**3 - elif name == "EPOCH_MICROSECOND": - scale = 10**6 - elif name == "EPOCH_NANOSECOND": - scale = 10**9 - else: - scale = None - - ts = self.expression(exp.Cast, this=expression, to=exp.DataType.build("TIMESTAMP")) - to_unix: exp.Expression = self.expression(exp.TimeToUnix, this=ts) - - if scale: - to_unix = exp.Mul(this=to_unix, expression=exp.Literal.number(scale)) - - return to_unix - - return self.expression(exp.Extract, this=this, expression=expression) - - def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: - if is_map: - # Keys are strings in Snowflake's objects, see also: - # - https://docs.snowflake.com/en/sql-reference/data-types-semistructured - # - https://docs.snowflake.com/en/sql-reference/functions/object_construct - return self._parse_slice(self._parse_string()) - - return self._parse_slice(self._parse_alias(self._parse_assignment(), explicit=True)) - - def _parse_lateral(self) -> t.Optional[exp.Lateral]: - lateral = super()._parse_lateral() - if not lateral: - return lateral - - if isinstance(lateral.this, exp.Explode): - table_alias = lateral.args.get("alias") - columns = [exp.to_identifier(col) for col in self.FLATTEN_COLUMNS] - if table_alias and not table_alias.args.get("columns"): - table_alias.set("columns", columns) - elif not table_alias: - 
exp.alias_(lateral, "_flattened", table=columns, copy=False) - - return lateral - - def _parse_table_parts( - self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False - ) -> exp.Table: - # https://docs.snowflake.com/en/user-guide/querying-stage - if self._match(TokenType.STRING, advance=False): - table = self._parse_string() - elif self._match_text_seq("@", advance=False): - table = self._parse_location_path() - else: - table = None - - if table: - file_format = None - pattern = None - - wrapped = self._match(TokenType.L_PAREN) - while self._curr and wrapped and not self._match(TokenType.R_PAREN): - if self._match_text_seq("FILE_FORMAT", "=>"): - file_format = self._parse_string() or super()._parse_table_parts( - is_db_reference=is_db_reference - ) - elif self._match_text_seq("PATTERN", "=>"): - pattern = self._parse_string() - else: - break - - self._match(TokenType.COMMA) - - table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern) - else: - table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference) - - return table - - def _parse_table( - self, - schema: bool = False, - joins: bool = False, - alias_tokens: t.Optional[t.Collection[TokenType]] = None, - parse_bracket: bool = False, - is_db_reference: bool = False, - parse_partition: bool = False, - ) -> t.Optional[exp.Expression]: - table = super()._parse_table( - schema=schema, - joins=joins, - alias_tokens=alias_tokens, - parse_bracket=parse_bracket, - is_db_reference=is_db_reference, - parse_partition=parse_partition, - ) - if isinstance(table, exp.Table) and isinstance(table.this, exp.TableFromRows): - table_from_rows = table.this - for arg in exp.TableFromRows.arg_types: - if arg != "this": - table_from_rows.set(arg, table.args.get(arg)) - - table = table_from_rows - - return table - - def _parse_id_var( - self, - any_token: bool = True, - tokens: t.Optional[t.Collection[TokenType]] = None, - ) -> t.Optional[exp.Expression]: - if self._match_text_seq("IDENTIFIER", "("): - identifier = ( - super()._parse_id_var(any_token=any_token, tokens=tokens) - or self._parse_string() - ) - self._match_r_paren() - return self.expression(exp.Anonymous, this="IDENTIFIER", expressions=[identifier]) - - return super()._parse_id_var(any_token=any_token, tokens=tokens) - - def _parse_show_snowflake(self, this: str) -> exp.Show: - scope = None - scope_kind = None - - # will identity SHOW TERSE SCHEMAS but not SHOW TERSE PRIMARY KEYS - # which is syntactically valid but has no effect on the output - terse = self._tokens[self._index - 2].text.upper() == "TERSE" - - history = self._match_text_seq("HISTORY") - - like = self._parse_string() if self._match(TokenType.LIKE) else None - - if self._match(TokenType.IN): - if self._match_text_seq("ACCOUNT"): - scope_kind = "ACCOUNT" - elif self._match_text_seq("CLASS"): - scope_kind = "CLASS" - scope = self._parse_table_parts() - elif self._match_text_seq("APPLICATION"): - scope_kind = "APPLICATION" - if self._match_text_seq("PACKAGE"): - scope_kind += " PACKAGE" - scope = self._parse_table_parts() - elif self._match_set(self.DB_CREATABLES): - scope_kind = self._prev.text.upper() - if self._curr: - scope = self._parse_table_parts() - elif self._curr: - scope_kind = "SCHEMA" if this in self.SCHEMA_KINDS else "TABLE" - scope = self._parse_table_parts() - - return self.expression( - exp.Show, - **{ - "terse": terse, - "this": this, - "history": history, - "like": like, - "scope": scope, - "scope_kind": scope_kind, - "starts_with": 
self._match_text_seq("STARTS", "WITH") and self._parse_string(), - "limit": self._parse_limit(), - "from": self._parse_string() if self._match(TokenType.FROM) else None, - "privileges": self._match_text_seq("WITH", "PRIVILEGES") - and self._parse_csv(lambda: self._parse_var(any_token=True, upper=True)), - }, - ) - - def _parse_put(self) -> exp.Put | exp.Command: - if self._curr.token_type != TokenType.STRING: - return self._parse_as_command(self._prev) - - return self.expression( - exp.Put, - this=self._parse_string(), - target=self._parse_location_path(), - properties=self._parse_properties(), - ) - - def _parse_get(self) -> t.Optional[exp.Expression]: - start = self._prev - - # If we detect GET( then we need to parse a function, not a statement - if self._match(TokenType.L_PAREN): - self._retreat(self._index - 2) - return self._parse_expression() - - target = self._parse_location_path() - - # Parse as command if unquoted file path - if self._curr.token_type == TokenType.URI_START: - return self._parse_as_command(start) - - return self.expression( - exp.Get, - this=self._parse_string(), - target=target, - properties=self._parse_properties(), - ) - - def _parse_location_property(self) -> exp.LocationProperty: - self._match(TokenType.EQ) - return self.expression(exp.LocationProperty, this=self._parse_location_path()) - - def _parse_file_location(self) -> t.Optional[exp.Expression]: - # Parse either a subquery or a staged file - return ( - self._parse_select(table=True, parse_subquery_alias=False) - if self._match(TokenType.L_PAREN, advance=False) - else self._parse_table_parts() - ) - - def _parse_location_path(self) -> exp.Var: - start = self._curr - self._advance_any(ignore_reserved=True) - - # We avoid consuming a comma token because external tables like @foo and @bar - # can be joined in a query with a comma separator, as well as closing paren - # in case of subqueries - while self._is_connected() and not self._match_set( - (TokenType.COMMA, TokenType.L_PAREN, TokenType.R_PAREN), advance=False - ): - self._advance_any(ignore_reserved=True) - - return exp.var(self._find_sql(start, self._prev)) - - def _parse_lambda_arg(self) -> t.Optional[exp.Expression]: - this = super()._parse_lambda_arg() - - if not this: - return this - - typ = self._parse_types() - - if typ: - return self.expression(exp.Cast, this=this, to=typ) - - return this - - def _parse_foreign_key(self) -> exp.ForeignKey: - # inlineFK, the REFERENCES columns are implied - if self._match(TokenType.REFERENCES, advance=False): - return self.expression(exp.ForeignKey) - - # outoflineFK, explicitly names the columns - return super()._parse_foreign_key() - - def _parse_file_format_property(self) -> exp.FileFormatProperty: - self._match(TokenType.EQ) - if self._match(TokenType.L_PAREN, advance=False): - expressions = self._parse_wrapped_options() - else: - expressions = [self._parse_format_name()] - - return self.expression( - exp.FileFormatProperty, - expressions=expressions, - ) - - def _parse_credentials_property(self) -> exp.CredentialsProperty: - return self.expression( - exp.CredentialsProperty, - expressions=self._parse_wrapped_options(), - ) - - class Tokenizer(tokens.Tokenizer): - STRING_ESCAPES = ["\\", "'"] - HEX_STRINGS = [("x'", "'"), ("X'", "'")] - RAW_STRINGS = ["$$"] - COMMENTS = ["--", "//", ("/*", "*/")] - NESTED_COMMENTS = False - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "FILE://": TokenType.URI_START, - "BYTEINT": TokenType.INT, - "EXCLUDE": TokenType.EXCEPT, - "FILE FORMAT": TokenType.FILE_FORMAT, - "GET": 
TokenType.GET, - "ILIKE ANY": TokenType.ILIKE_ANY, - "LIKE ANY": TokenType.LIKE_ANY, - "MATCH_CONDITION": TokenType.MATCH_CONDITION, - "MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE, - "MINUS": TokenType.EXCEPT, - "NCHAR VARYING": TokenType.VARCHAR, - "PUT": TokenType.PUT, - "REMOVE": TokenType.COMMAND, - "RM": TokenType.COMMAND, - "SAMPLE": TokenType.TABLE_SAMPLE, - "SQL_DOUBLE": TokenType.DOUBLE, - "SQL_VARCHAR": TokenType.VARCHAR, - "STORAGE INTEGRATION": TokenType.STORAGE_INTEGRATION, - "TAG": TokenType.TAG, - "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, - "TOP": TokenType.TOP, - "WAREHOUSE": TokenType.WAREHOUSE, - "STAGE": TokenType.STAGE, - "STREAMLIT": TokenType.STREAMLIT, - } - KEYWORDS.pop("/*+") - - SINGLE_TOKENS = { - **tokens.Tokenizer.SINGLE_TOKENS, - "$": TokenType.PARAMETER, - } - - VAR_SINGLE_TOKENS = {"$"} - - COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SHOW} - - class Generator(generator.Generator): - PARAMETER_TOKEN = "$" - MATCHED_BY_SOURCE = False - SINGLE_STRING_INTERVAL = True - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - AGGREGATE_FILTER_SUPPORTED = False - SUPPORTS_TABLE_COPY = False - COLLATE_IS_FUNC = True - LIMIT_ONLY_LITERALS = True - JSON_KEY_VALUE_PAIR_SEP = "," - INSERT_OVERWRITE = " OVERWRITE INTO" - STRUCT_DELIMITER = ("(", ")") - COPY_PARAMS_ARE_WRAPPED = False - COPY_PARAMS_EQ_REQUIRED = True - STAR_EXCEPT = "EXCLUDE" - SUPPORTS_EXPLODING_PROJECTIONS = False - ARRAY_CONCAT_IS_VAR_LEN = False - SUPPORTS_CONVERT_TIMEZONE = True - EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False - SUPPORTS_MEDIAN = True - ARRAY_SIZE_NAME = "ARRAY_SIZE" - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), - exp.ArgMax: rename_func("MAX_BY"), - exp.ArgMin: rename_func("MIN_BY"), - exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"), - exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this), - exp.ArrayIntersect: rename_func("ARRAY_INTERSECTION"), - exp.AtTimeZone: lambda self, e: self.func( - "CONVERT_TIMEZONE", e.args.get("zone"), e.this - ), - exp.BitwiseOr: rename_func("BITOR"), - exp.BitwiseXor: rename_func("BITXOR"), - exp.BitwiseLeftShift: rename_func("BITSHIFTLEFT"), - exp.BitwiseRightShift: rename_func("BITSHIFTRIGHT"), - exp.Create: transforms.preprocess([_flatten_structured_types_unless_iceberg]), - exp.DateAdd: date_delta_sql("DATEADD"), - exp.DateDiff: date_delta_sql("DATEDIFF"), - exp.DatetimeAdd: date_delta_sql("TIMESTAMPADD"), - exp.DatetimeDiff: timestampdiff_sql, - exp.DateStrToDate: datestrtodate_sql, - exp.DayOfMonth: rename_func("DAYOFMONTH"), - exp.DayOfWeek: rename_func("DAYOFWEEK"), - exp.DayOfWeekIso: rename_func("DAYOFWEEKISO"), - exp.DayOfYear: rename_func("DAYOFYEAR"), - exp.Explode: rename_func("FLATTEN"), - exp.Extract: lambda self, e: self.func( - "DATE_PART", map_date_part(e.this, self.dialect), e.expression - ), - exp.FileFormatProperty: lambda self, - e: f"FILE_FORMAT=({self.expressions(e, 'expressions', sep=' ')})", - exp.FromTimeZone: lambda self, e: self.func( - "CONVERT_TIMEZONE", e.args.get("zone"), "'UTC'", e.this - ), - exp.GenerateSeries: lambda self, e: self.func( - "ARRAY_GENERATE_RANGE", e.args["start"], e.args["end"] + 1, e.args.get("step") - ), - exp.GroupConcat: lambda self, e: groupconcat_sql(self, e, sep=""), - exp.If: if_sql(name="IFF", false_value="NULL"), - exp.JSONExtractArray: _json_extract_value_array_sql, - exp.JSONExtractScalar: lambda self, e: self.func( - "JSON_EXTRACT_PATH_TEXT", e.this, 
e.expression - ), - exp.JSONObject: lambda self, e: self.func("OBJECT_CONSTRUCT_KEEP_NULL", *e.expressions), - exp.JSONPathRoot: lambda *_: "", - exp.JSONValueArray: _json_extract_value_array_sql, - exp.Levenshtein: unsupported_args("ins_cost", "del_cost", "sub_cost")( - rename_func("EDITDISTANCE") - ), - exp.LocationProperty: lambda self, e: f"LOCATION={self.sql(e, 'this')}", - exp.LogicalAnd: rename_func("BOOLAND_AGG"), - exp.LogicalOr: rename_func("BOOLOR_AGG"), - exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), - exp.MakeInterval: no_make_interval_sql, - exp.Max: max_or_greatest, - exp.Min: min_or_least, - exp.ParseJSON: lambda self, e: self.func( - "TRY_PARSE_JSON" if e.args.get("safe") else "PARSE_JSON", e.this - ), - exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", - exp.PercentileCont: transforms.preprocess( - [transforms.add_within_group_for_percentiles] - ), - exp.PercentileDisc: transforms.preprocess( - [transforms.add_within_group_for_percentiles] - ), - exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]), - exp.RegexpExtract: _regexpextract_sql, - exp.RegexpExtractAll: _regexpextract_sql, - exp.RegexpILike: _regexpilike_sql, - exp.Rand: rename_func("RANDOM"), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_window_clause, - transforms.eliminate_distinct_on, - transforms.explode_projection_to_unnest(), - transforms.eliminate_semi_and_anti_joins, - _transform_generate_date_array, - ] - ), - exp.SHA: rename_func("SHA1"), - exp.StarMap: rename_func("OBJECT_CONSTRUCT"), - exp.StartsWith: rename_func("STARTSWITH"), - exp.EndsWith: rename_func("ENDSWITH"), - exp.StrPosition: lambda self, e: strposition_sql( - self, e, func_name="CHARINDEX", supports_position=True - ), - exp.StrToDate: lambda self, e: self.func("DATE", e.this, self.format_time(e)), - exp.StringToArray: rename_func("STRTOK_TO_ARRAY"), - exp.Stuff: rename_func("INSERT"), - exp.StPoint: rename_func("ST_MAKEPOINT"), - exp.TimeAdd: date_delta_sql("TIMEADD"), - exp.Timestamp: no_timestamp_sql, - exp.TimestampAdd: date_delta_sql("TIMESTAMPADD"), - exp.TimestampDiff: lambda self, e: self.func( - "TIMESTAMPDIFF", e.unit, e.expression, e.this - ), - exp.TimestampTrunc: timestamptrunc_sql(), - exp.TimeStrToTime: timestrtotime_sql, - exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", - exp.ToArray: rename_func("TO_ARRAY"), - exp.ToChar: lambda self, e: self.function_fallback_sql(e), - exp.ToDouble: rename_func("TO_DOUBLE"), - exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), - exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TsOrDsToDate: lambda self, e: self.func( - "TRY_TO_DATE" if e.args.get("safe") else "TO_DATE", e.this, self.format_time(e) - ), - exp.TsOrDsToTime: lambda self, e: self.func( - "TRY_TO_TIME" if e.args.get("safe") else "TO_TIME", e.this, self.format_time(e) - ), - exp.Unhex: rename_func("HEX_DECODE_BINARY"), - exp.UnixToTime: rename_func("TO_TIMESTAMP"), - exp.Uuid: rename_func("UUID_STRING"), - exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), - exp.WeekOfYear: rename_func("WEEKOFYEAR"), - exp.Xor: rename_func("BOOLXOR"), - } - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.NESTED: "OBJECT", - exp.DataType.Type.STRUCT: "OBJECT", - exp.DataType.Type.BIGDECIMAL: "DOUBLE", - } - - TOKEN_MAPPING = { - TokenType.AUTO_INCREMENT: "AUTOINCREMENT", - } - - 
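A brief sketch of how the generator TRANSFORMS above rename portable expressions to their Snowflake spellings (assuming sqlglot is installed; DuckDB is an arbitrary source dialect, the identifiers are placeholders, and the output shown is approximate):

```python
# exp.If is rendered as IFF and exp.StartsWith as STARTSWITH when targeting Snowflake.
import sqlglot

print(sqlglot.transpile(
    "SELECT IF(a > 0, 'pos', 'neg'), STARTS_WITH(name, 'ab') FROM t",
    read="duckdb",
    write="snowflake",
)[0])
# Roughly: SELECT IFF(a > 0, 'pos', 'neg'), STARTSWITH(name, 'ab') FROM t
```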
PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.CredentialsProperty: exp.Properties.Location.POST_WITH, - exp.LocationProperty: exp.Properties.Location.POST_WITH, - exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.SetProperty: exp.Properties.Location.UNSUPPORTED, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - UNSUPPORTED_VALUES_EXPRESSIONS = { - exp.Map, - exp.StarMap, - exp.Struct, - exp.VarMap, - } - - RESPECT_IGNORE_NULLS_UNSUPPORTED_EXPRESSIONS = (exp.ArrayAgg,) - - def with_properties(self, properties: exp.Properties) -> str: - return self.properties(properties, wrapped=False, prefix=self.sep(""), sep=" ") - - def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> str: - if expression.find(*self.UNSUPPORTED_VALUES_EXPRESSIONS): - values_as_table = False - - return super().values_sql(expression, values_as_table=values_as_table) - - def datatype_sql(self, expression: exp.DataType) -> str: - expressions = expression.expressions - if ( - expressions - and expression.is_type(*exp.DataType.STRUCT_TYPES) - and any(isinstance(field_type, exp.DataType) for field_type in expressions) - ): - # The correct syntax is OBJECT [ ( str: - return self.func( - "TO_NUMBER", - expression.this, - expression.args.get("format"), - expression.args.get("precision"), - expression.args.get("scale"), - ) - - def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: - milli = expression.args.get("milli") - if milli is not None: - milli_to_nano = milli.pop() * exp.Literal.number(1000000) - expression.set("nano", milli_to_nano) - - return rename_func("TIMESTAMP_FROM_PARTS")(self, expression) - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - if expression.is_type(exp.DataType.Type.GEOGRAPHY): - return self.func("TO_GEOGRAPHY", expression.this) - if expression.is_type(exp.DataType.Type.GEOMETRY): - return self.func("TO_GEOMETRY", expression.this) - - return super().cast_sql(expression, safe_prefix=safe_prefix) - - def trycast_sql(self, expression: exp.TryCast) -> str: - value = expression.this - - if value.type is None: - from sqlglot.optimizer.annotate_types import annotate_types - - value = annotate_types(value, dialect=self.dialect) - - if value.is_type(*exp.DataType.TEXT_TYPES, exp.DataType.Type.UNKNOWN): - return super().trycast_sql(expression) - - # TRY_CAST only works for string values in Snowflake - return self.cast_sql(expression) - - def log_sql(self, expression: exp.Log) -> str: - if not expression.expression: - return self.func("LN", expression.this) - - return super().log_sql(expression) - - def unnest_sql(self, expression: exp.Unnest) -> str: - unnest_alias = expression.args.get("alias") - offset = expression.args.get("offset") - - unnest_alias_columns = unnest_alias.columns if unnest_alias else [] - value = seq_get(unnest_alias_columns, 0) or exp.to_identifier("value") - - columns = [ - exp.to_identifier("seq"), - exp.to_identifier("key"), - exp.to_identifier("path"), - offset.pop() if isinstance(offset, exp.Expression) else exp.to_identifier("index"), - value, - exp.to_identifier("this"), - ] - - if unnest_alias: - unnest_alias.set("columns", columns) - else: - unnest_alias = exp.TableAlias(this="_u", columns=columns) - - table_input = self.sql(expression.expressions[0]) - if not table_input.startswith("INPUT =>"): - table_input = f"INPUT => {table_input}" - - explode = f"TABLE(FLATTEN({table_input}))" - alias = self.sql(unnest_alias) - alias = f" AS 
{alias}" if alias else "" - value = "" if isinstance(expression.parent, (exp.From, exp.Join)) else f"{value} FROM " - - return f"{value}{explode}{alias}" - - def show_sql(self, expression: exp.Show) -> str: - terse = "TERSE " if expression.args.get("terse") else "" - history = " HISTORY" if expression.args.get("history") else "" - like = self.sql(expression, "like") - like = f" LIKE {like}" if like else "" - - scope = self.sql(expression, "scope") - scope = f" {scope}" if scope else "" - - scope_kind = self.sql(expression, "scope_kind") - if scope_kind: - scope_kind = f" IN {scope_kind}" - - starts_with = self.sql(expression, "starts_with") - if starts_with: - starts_with = f" STARTS WITH {starts_with}" - - limit = self.sql(expression, "limit") - - from_ = self.sql(expression, "from") - if from_: - from_ = f" FROM {from_}" - - privileges = self.expressions(expression, key="privileges", flat=True) - privileges = f" WITH PRIVILEGES {privileges}" if privileges else "" - - return f"SHOW {terse}{expression.name}{history}{like}{scope_kind}{scope}{starts_with}{limit}{from_}{privileges}" - - def describe_sql(self, expression: exp.Describe) -> str: - # Default to table if kind is unknown - kind_value = expression.args.get("kind") or "TABLE" - kind = f" {kind_value}" if kind_value else "" - this = f" {self.sql(expression, 'this')}" - expressions = self.expressions(expression, flat=True) - expressions = f" {expressions}" if expressions else "" - return f"DESCRIBE{kind}{this}{expressions}" - - def generatedasidentitycolumnconstraint_sql( - self, expression: exp.GeneratedAsIdentityColumnConstraint - ) -> str: - start = expression.args.get("start") - start = f" START {start}" if start else "" - increment = expression.args.get("increment") - increment = f" INCREMENT {increment}" if increment else "" - return f"AUTOINCREMENT{start}{increment}" - - def cluster_sql(self, expression: exp.Cluster) -> str: - return f"CLUSTER BY ({self.expressions(expression, flat=True)})" - - def struct_sql(self, expression: exp.Struct) -> str: - keys = [] - values = [] - - for i, e in enumerate(expression.expressions): - if isinstance(e, exp.PropertyEQ): - keys.append( - exp.Literal.string(e.name) if isinstance(e.this, exp.Identifier) else e.this - ) - values.append(e.expression) - else: - keys.append(exp.Literal.string(f"_{i}")) - values.append(e) - - return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values))) - - @unsupported_args("weight", "accuracy") - def approxquantile_sql(self, expression: exp.ApproxQuantile) -> str: - return self.func("APPROX_PERCENTILE", expression.this, expression.args.get("quantile")) - - def alterset_sql(self, expression: exp.AlterSet) -> str: - exprs = self.expressions(expression, flat=True) - exprs = f" {exprs}" if exprs else "" - file_format = self.expressions(expression, key="file_format", flat=True, sep=" ") - file_format = f" STAGE_FILE_FORMAT = ({file_format})" if file_format else "" - copy_options = self.expressions(expression, key="copy_options", flat=True, sep=" ") - copy_options = f" STAGE_COPY_OPTIONS = ({copy_options})" if copy_options else "" - tag = self.expressions(expression, key="tag", flat=True) - tag = f" TAG {tag}" if tag else "" - - return f"SET{exprs}{file_format}{copy_options}{tag}" - - def strtotime_sql(self, expression: exp.StrToTime): - safe_prefix = "TRY_" if expression.args.get("safe") else "" - return self.func( - f"{safe_prefix}TO_TIMESTAMP", expression.this, self.format_time(expression) - ) - - def timestampsub_sql(self, expression: exp.TimestampSub): - return 
self.sql( - exp.TimestampAdd( - this=expression.this, - expression=expression.expression * -1, - unit=expression.unit, - ) - ) - - def jsonextract_sql(self, expression: exp.JSONExtract): - this = expression.this - - # JSON strings are valid coming from other dialects such as BQ - return self.func( - "GET_PATH", - exp.ParseJSON(this=this) if this.is_string else this, - expression.expression, - ) - - def timetostr_sql(self, expression: exp.TimeToStr) -> str: - this = expression.this - if not isinstance(this, exp.TsOrDsToTimestamp): - this = exp.cast(this, exp.DataType.Type.TIMESTAMP) - - return self.func("TO_CHAR", this, self.format_time(expression)) - - def datesub_sql(self, expression: exp.DateSub) -> str: - value = expression.expression - if value: - value.replace(value * (-1)) - else: - self.unsupported("DateSub cannot be transpiled if the subtracted count is unknown") - - return date_delta_sql("DATEADD")(self, expression) - - def select_sql(self, expression: exp.Select) -> str: - limit = expression.args.get("limit") - offset = expression.args.get("offset") - if offset and not limit: - expression.limit(exp.Null(), copy=False) - return super().select_sql(expression) - - def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: - is_materialized = expression.find(exp.MaterializedProperty) - copy_grants_property = expression.find(exp.CopyGrantsProperty) - - if expression.kind == "VIEW" and is_materialized and copy_grants_property: - # For materialized views, COPY GRANTS is located *before* the columns list - # This is in contrast to normal views where COPY GRANTS is located *after* the columns list - # We default CopyGrantsProperty to POST_SCHEMA which means we need to output it POST_NAME if a materialized view is detected - # ref: https://docs.snowflake.com/en/sql-reference/sql/create-materialized-view#syntax - # ref: https://docs.snowflake.com/en/sql-reference/sql/create-view#syntax - post_schema_properties = locations[exp.Properties.Location.POST_SCHEMA] - post_schema_properties.pop(post_schema_properties.index(copy_grants_property)) - - this_name = self.sql(expression.this, "this") - copy_grants = self.sql(copy_grants_property) - this_schema = self.schema_columns_sql(expression.this) - this_schema = f"{self.sep()}{this_schema}" if this_schema else "" - - return f"{this_name}{self.sep()}{copy_grants}{this_schema}" - - return super().createable_sql(expression, locations) - - def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: - this = expression.this - - # If an ORDER BY clause is present, we need to remove it from ARRAY_AGG - # and add it later as part of the WITHIN GROUP clause - order = this if isinstance(this, exp.Order) else None - if order: - expression.set("this", order.this.pop()) - - expr_sql = super().arrayagg_sql(expression) - - if order: - expr_sql = self.sql(exp.WithinGroup(this=expr_sql, expression=order)) - - return expr_sql - - def array_sql(self, expression: exp.Array) -> str: - expressions = expression.expressions - - first_expr = seq_get(expressions, 0) - if isinstance(first_expr, exp.Select): - # SELECT AS STRUCT foo AS alias_foo -> ARRAY_AGG(OBJECT_CONSTRUCT('alias_foo', foo)) - if first_expr.text("kind").upper() == "STRUCT": - object_construct_args = [] - for expr in first_expr.expressions: - # Alias case: SELECT AS STRUCT foo AS alias_foo -> OBJECT_CONSTRUCT('alias_foo', foo) - # Column case: SELECT AS STRUCT foo -> OBJECT_CONSTRUCT('foo', foo) - name = expr.this if isinstance(expr, exp.Alias) else expr - - 
object_construct_args.extend([exp.Literal.string(expr.alias_or_name), name]) - - array_agg = exp.ArrayAgg( - this=_build_object_construct(args=object_construct_args) - ) - - first_expr.set("kind", None) - first_expr.set("expressions", [array_agg]) - - return self.sql(first_expr.subquery()) - - return inline_array_sql(self, expression) diff --git a/altimate_packages/sqlglot/dialects/spark.py b/altimate_packages/sqlglot/dialects/spark.py deleted file mode 100644 index d7055534d..000000000 --- a/altimate_packages/sqlglot/dialects/spark.py +++ /dev/null @@ -1,202 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp -from sqlglot.dialects.dialect import rename_func, unit_to_var, timestampdiff_sql, build_date_delta -from sqlglot.dialects.hive import _build_with_ignore_nulls -from sqlglot.dialects.spark2 import Spark2, temporary_storage_provider, _build_as_cast -from sqlglot.helper import ensure_list, seq_get -from sqlglot.transforms import ( - ctas_with_tmp_tables_to_create_tmp_view, - remove_unique_constraints, - preprocess, - move_partitioned_by_to_schema_columns, -) - - -def _build_datediff(args: t.List) -> exp.Expression: - """ - Although Spark docs don't mention the "unit" argument, Spark3 added support for - it at some point. Databricks also supports this variant (see below). - - For example, in spark-sql (v3.3.1): - - SELECT DATEDIFF('2020-01-01', '2020-01-05') results in -4 - - SELECT DATEDIFF(day, '2020-01-01', '2020-01-05') results in 4 - - See also: - - https://docs.databricks.com/sql/language-manual/functions/datediff3.html - - https://docs.databricks.com/sql/language-manual/functions/datediff.html - """ - unit = None - this = seq_get(args, 0) - expression = seq_get(args, 1) - - if len(args) == 3: - unit = exp.var(t.cast(exp.Expression, this).name) - this = args[2] - - return exp.DateDiff( - this=exp.TsOrDsToDate(this=this), expression=exp.TsOrDsToDate(this=expression), unit=unit - ) - - -def _build_dateadd(args: t.List) -> exp.Expression: - expression = seq_get(args, 1) - - if len(args) == 2: - # DATE_ADD(startDate, numDays INTEGER) - # https://docs.databricks.com/en/sql/language-manual/functions/date_add.html - return exp.TsOrDsAdd( - this=seq_get(args, 0), expression=expression, unit=exp.Literal.string("DAY") - ) - - # DATE_ADD / DATEADD / TIMESTAMPADD(unit, value integer, expr) - # https://docs.databricks.com/en/sql/language-manual/functions/date_add3.html - return exp.TimestampAdd(this=seq_get(args, 2), expression=expression, unit=seq_get(args, 0)) - - -def _normalize_partition(e: exp.Expression) -> exp.Expression: - """Normalize the expressions in PARTITION BY (<expression>, <expression>, ...)""" - if isinstance(e, str): - return exp.to_identifier(e) - if isinstance(e, exp.Literal): - return exp.to_identifier(e.name) - return e - - -def _dateadd_sql(self: Spark.Generator, expression: exp.TsOrDsAdd | exp.TimestampAdd) -> str: - if not expression.unit or ( - isinstance(expression, exp.TsOrDsAdd) and expression.text("unit").upper() == "DAY" - ): - # Coming from Hive/Spark2 DATE_ADD or roundtripping the 2-arg version of Spark3/DB - return self.func("DATE_ADD", expression.this, expression.expression) - - this = self.func( - "DATE_ADD", - unit_to_var(expression), - expression.expression, - expression.this, - ) - - if isinstance(expression, exp.TsOrDsAdd): - # The 3 arg version of DATE_ADD produces a timestamp in Spark3/DB but possibly not - # in other dialects - return_type = expression.return_type - if not return_type.is_type(exp.DataType.Type.TIMESTAMP, 
exp.DataType.Type.DATETIME): - this = f"CAST({this} AS {return_type})" - - return this - - -class Spark(Spark2): - SUPPORTS_ORDER_BY_ALL = True - - class Tokenizer(Spark2.Tokenizer): - STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False - - RAW_STRINGS = [ - (prefix + q, q) - for q in t.cast(t.List[str], Spark2.Tokenizer.QUOTES) - for prefix in ("r", "R") - ] - - class Parser(Spark2.Parser): - FUNCTIONS = { - **Spark2.Parser.FUNCTIONS, - "ANY_VALUE": _build_with_ignore_nulls(exp.AnyValue), - "DATE_ADD": _build_dateadd, - "DATEADD": _build_dateadd, - "TIMESTAMPADD": _build_dateadd, - "TIMESTAMPDIFF": build_date_delta(exp.TimestampDiff), - "DATEDIFF": _build_datediff, - "DATE_DIFF": _build_datediff, - "TIMESTAMP_LTZ": _build_as_cast("TIMESTAMP_LTZ"), - "TIMESTAMP_NTZ": _build_as_cast("TIMESTAMP_NTZ"), - "TRY_ELEMENT_AT": lambda args: exp.Bracket( - this=seq_get(args, 0), - expressions=ensure_list(seq_get(args, 1)), - offset=1, - safe=True, - ), - } - - def _parse_generated_as_identity( - self, - ) -> ( - exp.GeneratedAsIdentityColumnConstraint - | exp.ComputedColumnConstraint - | exp.GeneratedAsRowColumnConstraint - ): - this = super()._parse_generated_as_identity() - if this.expression: - return self.expression(exp.ComputedColumnConstraint, this=this.expression) - return this - - class Generator(Spark2.Generator): - SUPPORTS_TO_NUMBER = True - PAD_FILL_PATTERN_IS_REQUIRED = False - SUPPORTS_CONVERT_TIMEZONE = True - SUPPORTS_MEDIAN = True - SUPPORTS_UNIX_SECONDS = True - - TYPE_MAPPING = { - **Spark2.Generator.TYPE_MAPPING, - exp.DataType.Type.MONEY: "DECIMAL(15, 4)", - exp.DataType.Type.SMALLMONEY: "DECIMAL(6, 4)", - exp.DataType.Type.UUID: "STRING", - exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP_LTZ", - exp.DataType.Type.TIMESTAMPNTZ: "TIMESTAMP_NTZ", - } - - TRANSFORMS = { - **Spark2.Generator.TRANSFORMS, - exp.ArrayConstructCompact: lambda self, e: self.func( - "ARRAY_COMPACT", self.func("ARRAY", *e.expressions) - ), - exp.Create: preprocess( - [ - remove_unique_constraints, - lambda e: ctas_with_tmp_tables_to_create_tmp_view( - e, temporary_storage_provider - ), - move_partitioned_by_to_schema_columns, - ] - ), - exp.EndsWith: rename_func("ENDSWITH"), - exp.PartitionedByProperty: lambda self, - e: f"PARTITIONED BY {self.wrap(self.expressions(sqls=[_normalize_partition(e) for e in e.this.expressions], skip_first=True))}", - exp.StartsWith: rename_func("STARTSWITH"), - exp.TsOrDsAdd: _dateadd_sql, - exp.TimestampAdd: _dateadd_sql, - exp.DatetimeDiff: timestampdiff_sql, - exp.TimestampDiff: timestampdiff_sql, - exp.TryCast: lambda self, e: ( - self.trycast_sql(e) if e.args.get("safe") else self.cast_sql(e) - ), - } - TRANSFORMS.pop(exp.AnyValue) - TRANSFORMS.pop(exp.DateDiff) - TRANSFORMS.pop(exp.Group) - - def bracket_sql(self, expression: exp.Bracket) -> str: - if expression.args.get("safe"): - key = seq_get(self.bracket_offset_expressions(expression, index_offset=1), 0) - return self.func("TRY_ELEMENT_AT", expression.this, key) - - return super().bracket_sql(expression) - - def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: - return f"GENERATED ALWAYS AS ({self.sql(expression, 'this')})" - - def anyvalue_sql(self, expression: exp.AnyValue) -> str: - return self.function_fallback_sql(expression) - - def datediff_sql(self, expression: exp.DateDiff) -> str: - end = self.sql(expression, "this") - start = self.sql(expression, "expression") - - if expression.unit: - return self.func("DATEDIFF", unit_to_var(expression), start, end) - - return 
self.func("DATEDIFF", end, start) diff --git a/altimate_packages/sqlglot/dialects/spark2.py b/altimate_packages/sqlglot/dialects/spark2.py deleted file mode 100644 index 0c9088758..000000000 --- a/altimate_packages/sqlglot/dialects/spark2.py +++ /dev/null @@ -1,349 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, transforms -from sqlglot.dialects.dialect import ( - binary_from_function, - build_formatted_time, - is_parse_json, - pivot_column_names, - rename_func, - trim_sql, - unit_to_str, -) -from sqlglot.dialects.hive import Hive -from sqlglot.helper import seq_get, ensure_list -from sqlglot.tokens import TokenType -from sqlglot.transforms import ( - preprocess, - remove_unique_constraints, - ctas_with_tmp_tables_to_create_tmp_view, - move_schema_columns_to_partitioned_by, -) - -if t.TYPE_CHECKING: - from sqlglot._typing import E - - from sqlglot.optimizer.annotate_types import TypeAnnotator - - -def _map_sql(self: Spark2.Generator, expression: exp.Map) -> str: - keys = expression.args.get("keys") - values = expression.args.get("values") - - if not keys or not values: - return self.func("MAP") - - return self.func("MAP_FROM_ARRAYS", keys, values) - - -def _build_as_cast(to_type: str) -> t.Callable[[t.List], exp.Expression]: - return lambda args: exp.Cast(this=seq_get(args, 0), to=exp.DataType.build(to_type)) - - -def _str_to_date(self: Spark2.Generator, expression: exp.StrToDate) -> str: - time_format = self.format_time(expression) - if time_format == Hive.DATE_FORMAT: - return self.func("TO_DATE", expression.this) - return self.func("TO_DATE", expression.this, time_format) - - -def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str: - scale = expression.args.get("scale") - timestamp = expression.this - - if scale is None: - return self.sql(exp.cast(exp.func("from_unixtime", timestamp), exp.DataType.Type.TIMESTAMP)) - if scale == exp.UnixToTime.SECONDS: - return self.func("TIMESTAMP_SECONDS", timestamp) - if scale == exp.UnixToTime.MILLIS: - return self.func("TIMESTAMP_MILLIS", timestamp) - if scale == exp.UnixToTime.MICROS: - return self.func("TIMESTAMP_MICROS", timestamp) - - unix_seconds = exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)) - return self.func("TIMESTAMP_SECONDS", unix_seconds) - - -def _unalias_pivot(expression: exp.Expression) -> exp.Expression: - """ - Spark doesn't allow PIVOT aliases, so we need to remove them and possibly wrap a - pivoted source in a subquery with the same alias to preserve the query's semantics. - - Example: - >>> from sqlglot import parse_one - >>> expr = parse_one("SELECT piv.x FROM tbl PIVOT (SUM(a) FOR b IN ('x')) piv") - >>> print(_unalias_pivot(expr).sql(dialect="spark")) - SELECT piv.x FROM (SELECT * FROM tbl PIVOT(SUM(a) FOR b IN ('x'))) AS piv - """ - if isinstance(expression, exp.From) and expression.this.args.get("pivots"): - pivot = expression.this.args["pivots"][0] - if pivot.alias: - alias = pivot.args["alias"].pop() - return exp.From( - this=expression.this.replace( - exp.select("*") - .from_(expression.this.copy(), copy=False) - .subquery(alias=alias, copy=False) - ) - ) - - return expression - - -def _unqualify_pivot_columns(expression: exp.Expression) -> exp.Expression: - """ - Spark doesn't allow the column referenced in the PIVOT's field to be qualified, - so we need to unqualify it. 
- - Example: - >>> from sqlglot import parse_one - >>> expr = parse_one("SELECT * FROM tbl PIVOT (SUM(tbl.sales) FOR tbl.quarter IN ('Q1', 'Q2'))") - >>> print(_unqualify_pivot_columns(expr).sql(dialect="spark")) - SELECT * FROM tbl PIVOT(SUM(tbl.sales) FOR quarter IN ('Q1', 'Q1')) - """ - if isinstance(expression, exp.Pivot): - expression.set( - "fields", [transforms.unqualify_columns(field) for field in expression.fields] - ) - - return expression - - -def temporary_storage_provider(expression: exp.Expression) -> exp.Expression: - # spark2, spark, Databricks require a storage provider for temporary tables - provider = exp.FileFormatProperty(this=exp.Literal.string("parquet")) - expression.args["properties"].append("expressions", provider) - return expression - - -def _annotate_by_similar_args( - self: TypeAnnotator, expression: E, *args: str, target_type: exp.DataType | exp.DataType.Type -) -> E: - """ - Infers the type of the expression according to the following rules: - - If all args are of the same type OR any arg is of target_type, the expr is inferred as such - - If any arg is of UNKNOWN type and none of target_type, the expr is inferred as UNKNOWN - """ - self._annotate_args(expression) - - expressions: t.List[exp.Expression] = [] - for arg in args: - arg_expr = expression.args.get(arg) - expressions.extend(expr for expr in ensure_list(arg_expr) if expr) - - last_datatype = None - - has_unknown = False - for expr in expressions: - if expr.is_type(exp.DataType.Type.UNKNOWN): - has_unknown = True - elif expr.is_type(target_type): - has_unknown = False - last_datatype = target_type - break - else: - last_datatype = expr.type - - self._set_type(expression, exp.DataType.Type.UNKNOWN if has_unknown else last_datatype) - return expression - - -class Spark2(Hive): - ANNOTATORS = { - **Hive.ANNOTATORS, - exp.Substring: lambda self, e: self._annotate_by_args(e, "this"), - exp.Concat: lambda self, e: _annotate_by_similar_args( - self, e, "expressions", target_type=exp.DataType.Type.TEXT - ), - exp.Pad: lambda self, e: _annotate_by_similar_args( - self, e, "this", "fill_pattern", target_type=exp.DataType.Type.TEXT - ), - } - - class Tokenizer(Hive.Tokenizer): - HEX_STRINGS = [("X'", "'"), ("x'", "'")] - - KEYWORDS = { - **Hive.Tokenizer.KEYWORDS, - "TIMESTAMP": TokenType.TIMESTAMPTZ, - } - - class Parser(Hive.Parser): - TRIM_PATTERN_FIRST = True - - FUNCTIONS = { - **Hive.Parser.FUNCTIONS, - "AGGREGATE": exp.Reduce.from_arg_list, - "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, - "BOOLEAN": _build_as_cast("boolean"), - "DATE": _build_as_cast("date"), - "DATE_TRUNC": lambda args: exp.TimestampTrunc( - this=seq_get(args, 1), unit=exp.var(seq_get(args, 0)) - ), - "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - "DOUBLE": _build_as_cast("double"), - "FLOAT": _build_as_cast("float"), - "FROM_UTC_TIMESTAMP": lambda args, dialect: exp.AtTimeZone( - this=exp.cast( - seq_get(args, 0) or exp.Var(this=""), - exp.DataType.Type.TIMESTAMP, - dialect=dialect, - ), - zone=seq_get(args, 1), - ), - "INT": _build_as_cast("int"), - "MAP_FROM_ARRAYS": exp.Map.from_arg_list, - "RLIKE": exp.RegexpLike.from_arg_list, - "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift), - "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift), - "STRING": _build_as_cast("string"), - "TIMESTAMP": 
_build_as_cast("timestamp"), - "TO_TIMESTAMP": lambda args: ( - _build_as_cast("timestamp")(args) - if len(args) == 1 - else build_formatted_time(exp.StrToTime, "spark")(args) - ), - "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, - "TO_UTC_TIMESTAMP": lambda args, dialect: exp.FromTimeZone( - this=exp.cast( - seq_get(args, 0) or exp.Var(this=""), - exp.DataType.Type.TIMESTAMP, - dialect=dialect, - ), - zone=seq_get(args, 1), - ), - "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)), - "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))), - } - - FUNCTION_PARSERS = { - **Hive.Parser.FUNCTION_PARSERS, - "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), - "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), - "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), - "MERGE": lambda self: self._parse_join_hint("MERGE"), - "SHUFFLEMERGE": lambda self: self._parse_join_hint("SHUFFLEMERGE"), - "MERGEJOIN": lambda self: self._parse_join_hint("MERGEJOIN"), - "SHUFFLE_HASH": lambda self: self._parse_join_hint("SHUFFLE_HASH"), - "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), - } - - def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]: - return self._match_text_seq("DROP", "COLUMNS") and self.expression( - exp.Drop, this=self._parse_schema(), kind="COLUMNS" - ) - - def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: - if len(aggregations) == 1: - return [] - return pivot_column_names(aggregations, dialect="spark") - - class Generator(Hive.Generator): - QUERY_HINTS = True - NVL2_SUPPORTED = True - CAN_IMPLEMENT_ARRAY_ANY = True - - PROPERTIES_LOCATION = { - **Hive.Generator.PROPERTIES_LOCATION, - exp.EngineProperty: exp.Properties.Location.UNSUPPORTED, - exp.AutoIncrementProperty: exp.Properties.Location.UNSUPPORTED, - exp.CharacterSetProperty: exp.Properties.Location.UNSUPPORTED, - exp.CollateProperty: exp.Properties.Location.UNSUPPORTED, - } - - TRANSFORMS = { - **Hive.Generator.TRANSFORMS, - exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), - exp.ArraySum: lambda self, - e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", - exp.ArrayToString: rename_func("ARRAY_JOIN"), - exp.AtTimeZone: lambda self, e: self.func( - "FROM_UTC_TIMESTAMP", e.this, e.args.get("zone") - ), - exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), - exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), - exp.Create: preprocess( - [ - remove_unique_constraints, - lambda e: ctas_with_tmp_tables_to_create_tmp_view( - e, temporary_storage_provider - ), - move_schema_columns_to_partitioned_by, - ] - ), - exp.DateFromParts: rename_func("MAKE_DATE"), - exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, unit_to_str(e)), - exp.DayOfMonth: rename_func("DAYOFMONTH"), - exp.DayOfWeek: rename_func("DAYOFWEEK"), - # (DAY_OF_WEEK(datetime) % 7) + 1 is equivalent to DAYOFWEEK_ISO(datetime) - exp.DayOfWeekIso: lambda self, e: f"(({self.func('DAYOFWEEK', e.this)} % 7) + 1)", - exp.DayOfYear: rename_func("DAYOFYEAR"), - exp.FileFormatProperty: lambda self, e: f"USING {e.name.upper()}", - exp.From: transforms.preprocess([_unalias_pivot]), - exp.FromTimeZone: lambda self, e: self.func( - "TO_UTC_TIMESTAMP", e.this, e.args.get("zone") - ), - exp.LogicalAnd: rename_func("BOOL_AND"), - exp.LogicalOr: rename_func("BOOL_OR"), - exp.Map: _map_sql, - exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]), - exp.Reduce: 
rename_func("AGGREGATE"), - exp.RegexpReplace: lambda self, e: self.func( - "REGEXP_REPLACE", - e.this, - e.expression, - e.args["replacement"], - e.args.get("position"), - ), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_qualify, - transforms.eliminate_distinct_on, - transforms.unnest_to_explode, - transforms.any_to_exists, - ] - ), - exp.StrToDate: _str_to_date, - exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)), - exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this), - exp.Trim: trim_sql, - exp.UnixToTime: _unix_to_time_sql, - exp.VariancePop: rename_func("VAR_POP"), - exp.WeekOfYear: rename_func("WEEKOFYEAR"), - exp.WithinGroup: transforms.preprocess( - [transforms.remove_within_group_for_percentiles] - ), - } - TRANSFORMS.pop(exp.ArraySort) - TRANSFORMS.pop(exp.ILike) - TRANSFORMS.pop(exp.Left) - TRANSFORMS.pop(exp.MonthsBetween) - TRANSFORMS.pop(exp.Right) - - WRAP_DERIVED_VALUES = False - CREATE_FUNCTION_RETURN_AS = False - - def struct_sql(self, expression: exp.Struct) -> str: - from sqlglot.generator import Generator - - return Generator.struct_sql(self, expression) - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - arg = expression.this - is_json_extract = isinstance( - arg, (exp.JSONExtract, exp.JSONExtractScalar) - ) and not arg.args.get("variant_extract") - - # We can't use a non-nested type (eg. STRING) as a schema - if expression.to.args.get("nested") and (is_parse_json(arg) or is_json_extract): - schema = f"'{self.sql(expression, 'to')}'" - return self.func("FROM_JSON", arg if is_json_extract else arg.this, schema) - - if is_parse_json(expression): - return self.func("TO_JSON", arg) - - return super(Hive.Generator, self).cast_sql(expression, safe_prefix=safe_prefix) diff --git a/altimate_packages/sqlglot/dialects/sqlite.py b/altimate_packages/sqlglot/dialects/sqlite.py deleted file mode 100644 index 701b880b9..000000000 --- a/altimate_packages/sqlglot/dialects/sqlite.py +++ /dev/null @@ -1,320 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - NormalizationStrategy, - any_value_to_max_sql, - arrow_json_extract_sql, - concat_to_dpipe_sql, - count_if_to_sum, - no_ilike_sql, - no_pivot_sql, - no_tablesample_sql, - no_trycast_sql, - rename_func, - strposition_sql, -) -from sqlglot.generator import unsupported_args -from sqlglot.tokens import TokenType - - -def _build_strftime(args: t.List) -> exp.Anonymous | exp.TimeToStr: - if len(args) == 1: - args.append(exp.CurrentTimestamp()) - if len(args) == 2: - return exp.TimeToStr(this=exp.TsOrDsToTimestamp(this=args[1]), format=args[0]) - return exp.Anonymous(this="STRFTIME", expressions=args) - - -def _transform_create(expression: exp.Expression) -> exp.Expression: - """Move primary key to a column and enforce auto_increment on primary keys.""" - schema = expression.this - - if isinstance(expression, exp.Create) and isinstance(schema, exp.Schema): - defs = {} - primary_key = None - - for e in schema.expressions: - if isinstance(e, exp.ColumnDef): - defs[e.name] = e - elif isinstance(e, exp.PrimaryKey): - primary_key = e - - if primary_key and len(primary_key.expressions) == 1: - column = defs[primary_key.expressions[0].name] - column.append( - "constraints", exp.ColumnConstraint(kind=exp.PrimaryKeyColumnConstraint()) - ) - schema.expressions.remove(primary_key) - else: - for 
column in defs.values(): - auto_increment = None - for constraint in column.constraints: - if isinstance(constraint.kind, exp.PrimaryKeyColumnConstraint): - break - if isinstance(constraint.kind, exp.AutoIncrementColumnConstraint): - auto_increment = constraint - if auto_increment: - column.constraints.remove(auto_increment) - - return expression - - -def _generated_to_auto_increment(expression: exp.Expression) -> exp.Expression: - if not isinstance(expression, exp.ColumnDef): - return expression - - generated = expression.find(exp.GeneratedAsIdentityColumnConstraint) - - if generated: - t.cast(exp.ColumnConstraint, generated.parent).pop() - - not_null = expression.find(exp.NotNullColumnConstraint) - if not_null: - t.cast(exp.ColumnConstraint, not_null.parent).pop() - - expression.append( - "constraints", exp.ColumnConstraint(kind=exp.AutoIncrementColumnConstraint()) - ) - - return expression - - -class SQLite(Dialect): - # https://sqlite.org/forum/forumpost/5e575586ac5c711b?raw - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE - SUPPORTS_SEMI_ANTI_JOIN = False - TYPED_DIVISION = True - SAFE_DIVISION = True - - class Tokenizer(tokens.Tokenizer): - IDENTIFIERS = ['"', ("[", "]"), "`"] - HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")] - - NESTED_COMMENTS = False - - KEYWORDS = tokens.Tokenizer.KEYWORDS.copy() - KEYWORDS.pop("/*+") - - COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.REPLACE} - - class Parser(parser.Parser): - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "EDITDIST3": exp.Levenshtein.from_arg_list, - "STRFTIME": _build_strftime, - "DATETIME": lambda args: exp.Anonymous(this="DATETIME", expressions=args), - "TIME": lambda args: exp.Anonymous(this="TIME", expressions=args), - } - - STRING_ALIASES = True - ALTER_RENAME_REQUIRES_COLUMN = False - - def _parse_unique(self) -> exp.UniqueColumnConstraint: - # Do not consume more tokens if UNIQUE is used as a standalone constraint, e.g: - # CREATE TABLE foo (bar TEXT UNIQUE REFERENCES baz ...) 
- if self._curr.text.upper() in self.CONSTRAINT_PARSERS: - return self.expression(exp.UniqueColumnConstraint) - - return super()._parse_unique() - - class Generator(generator.Generator): - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - NVL2_SUPPORTED = False - JSON_PATH_BRACKETED_KEY_SUPPORTED = False - SUPPORTS_CREATE_TABLE_LIKE = False - SUPPORTS_TABLE_ALIAS_COLUMNS = False - SUPPORTS_TO_NUMBER = False - SUPPORTS_WINDOW_EXCLUDE = True - EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False - SUPPORTS_MEDIAN = False - JSON_KEY_VALUE_PAIR_SEP = "," - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.BOOLEAN: "INTEGER", - exp.DataType.Type.TINYINT: "INTEGER", - exp.DataType.Type.SMALLINT: "INTEGER", - exp.DataType.Type.INT: "INTEGER", - exp.DataType.Type.BIGINT: "INTEGER", - exp.DataType.Type.FLOAT: "REAL", - exp.DataType.Type.DOUBLE: "REAL", - exp.DataType.Type.DECIMAL: "REAL", - exp.DataType.Type.CHAR: "TEXT", - exp.DataType.Type.NCHAR: "TEXT", - exp.DataType.Type.VARCHAR: "TEXT", - exp.DataType.Type.NVARCHAR: "TEXT", - exp.DataType.Type.BINARY: "BLOB", - exp.DataType.Type.VARBINARY: "BLOB", - } - TYPE_MAPPING.pop(exp.DataType.Type.BLOB) - - TOKEN_MAPPING = { - TokenType.AUTO_INCREMENT: "AUTOINCREMENT", - } - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.AnyValue: any_value_to_max_sql, - exp.Chr: rename_func("CHAR"), - exp.Concat: concat_to_dpipe_sql, - exp.CountIf: count_if_to_sum, - exp.Create: transforms.preprocess([_transform_create]), - exp.CurrentDate: lambda *_: "CURRENT_DATE", - exp.CurrentTime: lambda *_: "CURRENT_TIME", - exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.ColumnDef: transforms.preprocess([_generated_to_auto_increment]), - exp.DateStrToDate: lambda self, e: self.sql(e, "this"), - exp.If: rename_func("IIF"), - exp.ILike: no_ilike_sql, - exp.JSONExtractScalar: arrow_json_extract_sql, - exp.Levenshtein: unsupported_args("ins_cost", "del_cost", "sub_cost", "max_dist")( - rename_func("EDITDIST3") - ), - exp.LogicalOr: rename_func("MAX"), - exp.LogicalAnd: rename_func("MIN"), - exp.Pivot: no_pivot_sql, - exp.Rand: rename_func("RANDOM"), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_distinct_on, - transforms.eliminate_qualify, - transforms.eliminate_semi_and_anti_joins, - ] - ), - exp.StrPosition: lambda self, e: strposition_sql(self, e, func_name="INSTR"), - exp.TableSample: no_tablesample_sql, - exp.TimeStrToTime: lambda self, e: self.sql(e, "this"), - exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.args.get("format"), e.this), - exp.TryCast: no_trycast_sql, - exp.TsOrDsToTimestamp: lambda self, e: self.sql(e, "this"), - } - - # SQLite doesn't generally support CREATE TABLE .. properties - # https://www.sqlite.org/lang_createtable.html - PROPERTIES_LOCATION = { - prop: exp.Properties.Location.UNSUPPORTED - for prop in generator.Generator.PROPERTIES_LOCATION - } - - # There are a few exceptions (e.g. 
temporary tables) which are supported or - # can be transpiled to SQLite, so we explicitly override them accordingly - PROPERTIES_LOCATION[exp.LikeProperty] = exp.Properties.Location.POST_SCHEMA - PROPERTIES_LOCATION[exp.TemporaryProperty] = exp.Properties.Location.POST_CREATE - - LIMIT_FETCH = "LIMIT" - - def jsonextract_sql(self, expression: exp.JSONExtract) -> str: - if expression.expressions: - return self.function_fallback_sql(expression) - return arrow_json_extract_sql(self, expression) - - def dateadd_sql(self, expression: exp.DateAdd) -> str: - modifier = expression.expression - modifier = modifier.name if modifier.is_string else self.sql(modifier) - unit = expression.args.get("unit") - modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'" - return self.func("DATE", expression.this, modifier) - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - if expression.is_type("date"): - return self.func("DATE", expression.this) - - return super().cast_sql(expression) - - def generateseries_sql(self, expression: exp.GenerateSeries) -> str: - parent = expression.parent - alias = parent and parent.args.get("alias") - - if isinstance(alias, exp.TableAlias) and alias.columns: - column_alias = alias.columns[0] - alias.set("columns", None) - sql = self.sql( - exp.select(exp.alias_("value", column_alias)).from_(expression).subquery() - ) - else: - sql = self.function_fallback_sql(expression) - - return sql - - def datediff_sql(self, expression: exp.DateDiff) -> str: - unit = expression.args.get("unit") - unit = unit.name.upper() if unit else "DAY" - - sql = f"(JULIANDAY({self.sql(expression, 'this')}) - JULIANDAY({self.sql(expression, 'expression')}))" - - if unit == "MONTH": - sql = f"{sql} / 30.0" - elif unit == "YEAR": - sql = f"{sql} / 365.0" - elif unit == "HOUR": - sql = f"{sql} * 24.0" - elif unit == "MINUTE": - sql = f"{sql} * 1440.0" - elif unit == "SECOND": - sql = f"{sql} * 86400.0" - elif unit == "MILLISECOND": - sql = f"{sql} * 86400000.0" - elif unit == "MICROSECOND": - sql = f"{sql} * 86400000000.0" - elif unit == "NANOSECOND": - sql = f"{sql} * 8640000000000.0" - else: - self.unsupported(f"DATEDIFF unsupported for '{unit}'.") - - return f"CAST({sql} AS INTEGER)" - - # https://www.sqlite.org/lang_aggfunc.html#group_concat - def groupconcat_sql(self, expression: exp.GroupConcat) -> str: - this = expression.this - distinct = expression.find(exp.Distinct) - - if distinct: - this = distinct.expressions[0] - distinct_sql = "DISTINCT " - else: - distinct_sql = "" - - if isinstance(expression.this, exp.Order): - self.unsupported("SQLite GROUP_CONCAT doesn't support ORDER BY.") - if expression.this.this and not distinct: - this = expression.this.this - - separator = expression.args.get("separator") - return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})" - - def least_sql(self, expression: exp.Least) -> str: - if len(expression.expressions) > 1: - return rename_func("MIN")(self, expression) - - return self.sql(expression, "this") - - def transaction_sql(self, expression: exp.Transaction) -> str: - this = expression.this - this = f" {this}" if this else "" - return f"BEGIN{this} TRANSACTION" - - def isascii_sql(self, expression: exp.IsAscii) -> str: - return f"(NOT {self.sql(expression.this)} GLOB CAST(x'2a5b5e012d7f5d2a' AS TEXT))" - - @unsupported_args("this") - def currentschema_sql(self, expression: exp.CurrentSchema) -> str: - return "'main'" - - def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str: - 
self.unsupported("SQLite does not support IGNORE NULLS.") - return self.sql(expression.this) - - def respectnulls_sql(self, expression: exp.RespectNulls) -> str: - return self.sql(expression.this) diff --git a/altimate_packages/sqlglot/dialects/starrocks.py b/altimate_packages/sqlglot/dialects/starrocks.py deleted file mode 100644 index 2079eea0a..000000000 --- a/altimate_packages/sqlglot/dialects/starrocks.py +++ /dev/null @@ -1,343 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp -from sqlglot.dialects.dialect import ( - approx_count_distinct_sql, - arrow_json_extract_sql, - build_timestamp_trunc, - rename_func, - unit_to_str, - inline_array_sql, - property_sql, -) -from sqlglot.dialects.mysql import MySQL -from sqlglot.helper import seq_get -from sqlglot.tokens import TokenType - - -# https://docs.starrocks.io/docs/sql-reference/sql-functions/spatial-functions/st_distance_sphere/ -def st_distance_sphere(self, expression: exp.StDistance) -> str: - point1 = expression.this - point2 = expression.expression - - point1_x = self.func("ST_X", point1) - point1_y = self.func("ST_Y", point1) - point2_x = self.func("ST_X", point2) - point2_y = self.func("ST_Y", point2) - - return self.func("ST_Distance_Sphere", point1_x, point1_y, point2_x, point2_y) - - -class StarRocks(MySQL): - STRICT_JSON_PATH_SYNTAX = False - - class Tokenizer(MySQL.Tokenizer): - KEYWORDS = { - **MySQL.Tokenizer.KEYWORDS, - "LARGEINT": TokenType.INT128, - } - - class Parser(MySQL.Parser): - FUNCTIONS = { - **MySQL.Parser.FUNCTIONS, - "DATE_TRUNC": build_timestamp_trunc, - "DATEDIFF": lambda args: exp.DateDiff( - this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") - ), - "DATE_DIFF": lambda args: exp.DateDiff( - this=seq_get(args, 1), expression=seq_get(args, 2), unit=seq_get(args, 0) - ), - "REGEXP": exp.RegexpLike.from_arg_list, - } - - PROPERTY_PARSERS = { - **MySQL.Parser.PROPERTY_PARSERS, - "UNIQUE": lambda self: self._parse_composite_key_property(exp.UniqueKeyProperty), - "PROPERTIES": lambda self: self._parse_wrapped_properties(), - "PARTITION BY": lambda self: self._parse_partition_by_opt_range(), - } - - def _parse_create(self) -> exp.Create | exp.Command: - create = super()._parse_create() - - # Starrocks' primary key is defined outside of the schema, so we need to move it there - # https://docs.starrocks.io/docs/table_design/table_types/primary_key_table/#usage - if isinstance(create, exp.Create) and isinstance(create.this, exp.Schema): - props = create.args.get("properties") - if props: - primary_key = props.find(exp.PrimaryKey) - if primary_key: - create.this.append("expressions", primary_key.pop()) - - return create - - def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: - unnest = super()._parse_unnest(with_alias=with_alias) - - if unnest: - alias = unnest.args.get("alias") - - if not alias: - # Starrocks defaults to naming the table alias as "unnest" - alias = exp.TableAlias( - this=exp.to_identifier("unnest"), columns=[exp.to_identifier("unnest")] - ) - unnest.set("alias", alias) - elif not alias.args.get("columns"): - # Starrocks defaults to naming the UNNEST column as "unnest" - # if it's not otherwise specified - alias.set("columns", [exp.to_identifier("unnest")]) - - return unnest - - def _parse_partitioning_granularity_dynamic(self) -> exp.PartitionByRangePropertyDynamic: - self._match_text_seq("START") - start = self._parse_wrapped(self._parse_string) - self._match_text_seq("END") - end = 
self._parse_wrapped(self._parse_string) - self._match_text_seq("EVERY") - every = self._parse_wrapped(lambda: self._parse_interval() or self._parse_number()) - return self.expression( - exp.PartitionByRangePropertyDynamic, start=start, end=end, every=every - ) - - def _parse_partition_by_opt_range( - self, - ) -> exp.PartitionedByProperty | exp.PartitionByRangeProperty: - if self._match_text_seq("RANGE"): - partition_expressions = self._parse_wrapped_id_vars() - create_expressions = self._parse_wrapped_csv( - self._parse_partitioning_granularity_dynamic - ) - return self.expression( - exp.PartitionByRangeProperty, - partition_expressions=partition_expressions, - create_expressions=create_expressions, - ) - return super()._parse_partitioned_by() - - class Generator(MySQL.Generator): - EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False - JSON_TYPE_REQUIRED_FOR_EXTRACTION = False - VARCHAR_REQUIRES_SIZE = False - PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON" - WITH_PROPERTIES_PREFIX = "PROPERTIES" - - CAST_MAPPING = {} - - TYPE_MAPPING = { - **MySQL.Generator.TYPE_MAPPING, - exp.DataType.Type.INT128: "LARGEINT", - exp.DataType.Type.TEXT: "STRING", - exp.DataType.Type.TIMESTAMP: "DATETIME", - exp.DataType.Type.TIMESTAMPTZ: "DATETIME", - } - - PROPERTIES_LOCATION = { - **MySQL.Generator.PROPERTIES_LOCATION, - exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA, - exp.UniqueKeyProperty: exp.Properties.Location.POST_SCHEMA, - exp.PartitionByRangeProperty: exp.Properties.Location.POST_SCHEMA, - } - - TRANSFORMS = { - **MySQL.Generator.TRANSFORMS, - exp.Array: inline_array_sql, - exp.ArrayAgg: rename_func("ARRAY_AGG"), - exp.ArrayFilter: rename_func("ARRAY_FILTER"), - exp.ArrayToString: rename_func("ARRAY_JOIN"), - exp.ApproxDistinct: approx_count_distinct_sql, - exp.DateDiff: lambda self, e: self.func( - "DATE_DIFF", unit_to_str(e), e.this, e.expression - ), - exp.JSONExtractScalar: arrow_json_extract_sql, - exp.JSONExtract: arrow_json_extract_sql, - exp.Property: property_sql, - exp.RegexpLike: rename_func("REGEXP"), - exp.StDistance: st_distance_sphere, - exp.StrToUnix: lambda self, e: self.func("UNIX_TIMESTAMP", e.this, self.format_time(e)), - exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this), - exp.TimeStrToDate: rename_func("TO_DATE"), - exp.UnixToStr: lambda self, e: self.func("FROM_UNIXTIME", e.this, self.format_time(e)), - exp.UnixToTime: rename_func("FROM_UNIXTIME"), - } - - TRANSFORMS.pop(exp.DateTrunc) - - # https://docs.starrocks.io/docs/sql-reference/sql-statements/keywords/#reserved-keywords - RESERVED_KEYWORDS = { - "add", - "all", - "alter", - "analyze", - "and", - "array", - "as", - "asc", - "between", - "bigint", - "bitmap", - "both", - "by", - "case", - "char", - "character", - "check", - "collate", - "column", - "compaction", - "convert", - "create", - "cross", - "cube", - "current_date", - "current_role", - "current_time", - "current_timestamp", - "current_user", - "database", - "databases", - "decimal", - "decimalv2", - "decimal32", - "decimal64", - "decimal128", - "default", - "deferred", - "delete", - "dense_rank", - "desc", - "describe", - "distinct", - "double", - "drop", - "dual", - "else", - "except", - "exists", - "explain", - "false", - "first_value", - "float", - "for", - "force", - "from", - "full", - "function", - "grant", - "group", - "grouping", - "grouping_id", - "groups", - "having", - "hll", - "host", - "if", - "ignore", - "immediate", - "in", - "index", - "infile", - "inner", - "insert", - "int", - "integer", - "intersect", - 
"into", - "is", - "join", - "json", - "key", - "keys", - "kill", - "lag", - "largeint", - "last_value", - "lateral", - "lead", - "left", - "like", - "limit", - "load", - "localtime", - "localtimestamp", - "maxvalue", - "minus", - "mod", - "not", - "ntile", - "null", - "on", - "or", - "order", - "outer", - "outfile", - "over", - "partition", - "percentile", - "primary", - "procedure", - "qualify", - "range", - "rank", - "read", - "regexp", - "release", - "rename", - "replace", - "revoke", - "right", - "rlike", - "row", - "row_number", - "rows", - "schema", - "schemas", - "select", - "set", - "set_var", - "show", - "smallint", - "system", - "table", - "terminated", - "text", - "then", - "tinyint", - "to", - "true", - "union", - "unique", - "unsigned", - "update", - "use", - "using", - "values", - "varchar", - "when", - "where", - "with", - } - - def create_sql(self, expression: exp.Create) -> str: - # Starrocks' primary key is defined outside of the schema, so we need to move it there - schema = expression.this - if isinstance(schema, exp.Schema): - primary_key = schema.find(exp.PrimaryKey) - - if primary_key: - props = expression.args.get("properties") - - if not props: - props = exp.Properties(expressions=[]) - expression.set("properties", props) - - # Verify if the first one is an engine property. Is true then insert it after the engine, - # otherwise insert it at the beginning - engine = props.find(exp.EngineProperty) - engine_index = (engine.index or 0) if engine else -1 - props.set("expressions", primary_key.pop(), engine_index + 1, overwrite=False) - - return super().create_sql(expression) diff --git a/altimate_packages/sqlglot/dialects/tableau.py b/altimate_packages/sqlglot/dialects/tableau.py deleted file mode 100644 index 61b2e2c14..000000000 --- a/altimate_packages/sqlglot/dialects/tableau.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import Dialect, rename_func, strposition_sql as _strposition_sql -from sqlglot.helper import seq_get - - -class Tableau(Dialect): - LOG_BASE_FIRST = False - - class Tokenizer(tokens.Tokenizer): - IDENTIFIERS = [("[", "]")] - QUOTES = ["'", '"'] - - class Generator(generator.Generator): - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.Coalesce: rename_func("IFNULL"), - exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), - } - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - def if_sql(self, expression: exp.If) -> str: - this = self.sql(expression, "this") - true = self.sql(expression, "true") - false = self.sql(expression, "false") - return f"IF {this} THEN {true} ELSE {false} END" - - def count_sql(self, expression: exp.Count) -> str: - this = expression.this - if isinstance(this, exp.Distinct): - return self.func("COUNTD", *this.expressions) - return self.func("COUNT", this) - - def strposition_sql(self, expression: exp.StrPosition) -> str: - has_occurrence = "occurrence" in expression.args - return _strposition_sql( - self, - expression, - func_name="FINDNTH" if has_occurrence else "FIND", - supports_occurrence=has_occurrence, - ) - - class Parser(parser.Parser): - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), - "FIND": exp.StrPosition.from_arg_list, - "FINDNTH": lambda args: 
exp.StrPosition( - this=seq_get(args, 0), substr=seq_get(args, 1), occurrence=seq_get(args, 2) - ), - } - NO_PAREN_IF_COMMANDS = False diff --git a/altimate_packages/sqlglot/dialects/teradata.py b/altimate_packages/sqlglot/dialects/teradata.py deleted file mode 100644 index d69a2aa1a..000000000 --- a/altimate_packages/sqlglot/dialects/teradata.py +++ /dev/null @@ -1,356 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - max_or_greatest, - min_or_least, - rename_func, - strposition_sql, - to_number_with_nls_param, -) -from sqlglot.helper import seq_get -from sqlglot.tokens import TokenType - - -def _date_add_sql( - kind: t.Literal["+", "-"], -) -> t.Callable[[Teradata.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: Teradata.Generator, expression: exp.DateAdd | exp.DateSub) -> str: - this = self.sql(expression, "this") - unit = expression.args.get("unit") - value = self._simplify_unless_literal(expression.expression) - - if not isinstance(value, exp.Literal): - self.unsupported("Cannot add non literal") - - if isinstance(value, exp.Neg): - kind_to_op = {"+": "-", "-": "+"} - value = exp.Literal.string(value.this.to_py()) - else: - kind_to_op = {"+": "+", "-": "-"} - value.set("is_string", True) - - return f"{this} {kind_to_op[kind]} {self.sql(exp.Interval(this=value, unit=unit))}" - - return func - - -class Teradata(Dialect): - SUPPORTS_SEMI_ANTI_JOIN = False - TYPED_DIVISION = True - - TIME_MAPPING = { - "YY": "%y", - "Y4": "%Y", - "YYYY": "%Y", - "M4": "%B", - "M3": "%b", - "M": "%-M", - "MI": "%M", - "MM": "%m", - "MMM": "%b", - "MMMM": "%B", - "D": "%-d", - "DD": "%d", - "D3": "%j", - "DDD": "%j", - "H": "%-H", - "HH": "%H", - "HH24": "%H", - "S": "%-S", - "SS": "%S", - "SSSSSS": "%f", - "E": "%a", - "EE": "%a", - "E3": "%a", - "E4": "%A", - "EEE": "%a", - "EEEE": "%A", - } - - class Tokenizer(tokens.Tokenizer): - # Tested each of these and they work, although there is no - # Teradata documentation explicitly mentioning them. 
- HEX_STRINGS = [("X'", "'"), ("x'", "'"), ("0x", "")] - # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Comparison-Operators-and-Functions/Comparison-Operators/ANSI-Compliance - # https://docs.teradata.com/r/SQL-Functions-Operators-Expressions-and-Predicates/June-2017/Arithmetic-Trigonometric-Hyperbolic-Operators/Functions - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "**": TokenType.DSTAR, - "^=": TokenType.NEQ, - "BYTEINT": TokenType.SMALLINT, - "COLLECT": TokenType.COMMAND, - "DEL": TokenType.DELETE, - "EQ": TokenType.EQ, - "GE": TokenType.GTE, - "GT": TokenType.GT, - "HELP": TokenType.COMMAND, - "INS": TokenType.INSERT, - "LE": TokenType.LTE, - "LT": TokenType.LT, - "MINUS": TokenType.EXCEPT, - "MOD": TokenType.MOD, - "NE": TokenType.NEQ, - "NOT=": TokenType.NEQ, - "SAMPLE": TokenType.TABLE_SAMPLE, - "SEL": TokenType.SELECT, - "ST_GEOMETRY": TokenType.GEOMETRY, - "TOP": TokenType.TOP, - "UPD": TokenType.UPDATE, - } - KEYWORDS.pop("/*+") - - # Teradata does not support % as a modulo operator - SINGLE_TOKENS = {**tokens.Tokenizer.SINGLE_TOKENS} - SINGLE_TOKENS.pop("%") - - class Parser(parser.Parser): - TABLESAMPLE_CSV = True - VALUES_FOLLOWED_BY_PAREN = False - - CHARSET_TRANSLATORS = { - "GRAPHIC_TO_KANJISJIS", - "GRAPHIC_TO_LATIN", - "GRAPHIC_TO_UNICODE", - "GRAPHIC_TO_UNICODE_PadSpace", - "KANJI1_KanjiEBCDIC_TO_UNICODE", - "KANJI1_KanjiEUC_TO_UNICODE", - "KANJI1_KANJISJIS_TO_UNICODE", - "KANJI1_SBC_TO_UNICODE", - "KANJISJIS_TO_GRAPHIC", - "KANJISJIS_TO_LATIN", - "KANJISJIS_TO_UNICODE", - "LATIN_TO_GRAPHIC", - "LATIN_TO_KANJISJIS", - "LATIN_TO_UNICODE", - "LOCALE_TO_UNICODE", - "UNICODE_TO_GRAPHIC", - "UNICODE_TO_GRAPHIC_PadGraphic", - "UNICODE_TO_GRAPHIC_VarGraphic", - "UNICODE_TO_KANJI1_KanjiEBCDIC", - "UNICODE_TO_KANJI1_KanjiEUC", - "UNICODE_TO_KANJI1_KANJISJIS", - "UNICODE_TO_KANJI1_SBC", - "UNICODE_TO_KANJISJIS", - "UNICODE_TO_LATIN", - "UNICODE_TO_LOCALE", - "UNICODE_TO_UNICODE_FoldSpace", - "UNICODE_TO_UNICODE_Fullwidth", - "UNICODE_TO_UNICODE_Halfwidth", - "UNICODE_TO_UNICODE_NFC", - "UNICODE_TO_UNICODE_NFD", - "UNICODE_TO_UNICODE_NFKC", - "UNICODE_TO_UNICODE_NFKD", - } - - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS} - FUNC_TOKENS.remove(TokenType.REPLACE) - - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, - TokenType.DATABASE: lambda self: self.expression( - exp.Use, this=self._parse_table(schema=False) - ), - TokenType.REPLACE: lambda self: self._parse_create(), - } - - FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, - # https://docs.teradata.com/r/SQL-Functions-Operators-Expressions-and-Predicates/June-2017/Data-Type-Conversions/TRYCAST - "TRYCAST": parser.Parser.FUNCTION_PARSERS["TRY_CAST"], - "RANGE_N": lambda self: self._parse_rangen(), - "TRANSLATE": lambda self: self._parse_translate(), - } - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "CARDINALITY": exp.ArraySize.from_arg_list, - "RANDOM": lambda args: exp.Rand(lower=seq_get(args, 0), upper=seq_get(args, 1)), - } - - EXPONENT = { - TokenType.DSTAR: exp.Pow, - } - - def _parse_translate(self) -> exp.TranslateCharacters: - this = self._parse_assignment() - self._match(TokenType.USING) - self._match_texts(self.CHARSET_TRANSLATORS) - - return self.expression( - exp.TranslateCharacters, - this=this, - expression=self._prev.text.upper(), - with_error=self._match_text_seq("WITH", "ERROR"), - ) - - # FROM before SET in Teradata UPDATE syntax - # 
https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause - def _parse_update(self) -> exp.Update: - return self.expression( - exp.Update, - **{ # type: ignore - "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), - "from": self._parse_from(joins=True), - "expressions": self._match(TokenType.SET) - and self._parse_csv(self._parse_equality), - "where": self._parse_where(), - }, - ) - - def _parse_rangen(self): - this = self._parse_id_var() - self._match(TokenType.BETWEEN) - - expressions = self._parse_csv(self._parse_assignment) - each = self._match_text_seq("EACH") and self._parse_assignment() - - return self.expression(exp.RangeN, this=this, expressions=expressions, each=each) - - def _parse_index_params(self) -> exp.IndexParameters: - this = super()._parse_index_params() - - if this.args.get("on"): - this.set("on", None) - self._retreat(self._index - 2) - return this - - class Generator(generator.Generator): - LIMIT_IS_TOP = True - JOIN_HINTS = False - TABLE_HINTS = False - QUERY_HINTS = False - TABLESAMPLE_KEYWORDS = "SAMPLE" - LAST_DAY_SUPPORTS_DATE_PART = False - CAN_IMPLEMENT_ARRAY_ANY = True - TZ_TO_WITH_TIME_ZONE = True - ARRAY_SIZE_NAME = "CARDINALITY" - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.GEOMETRY: "ST_GEOMETRY", - exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", - exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", - } - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.OnCommitProperty: exp.Properties.Location.POST_INDEX, - exp.PartitionedByProperty: exp.Properties.Location.POST_EXPRESSION, - exp.StabilityProperty: exp.Properties.Location.POST_CREATE, - } - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.ArgMax: rename_func("MAX_BY"), - exp.ArgMin: rename_func("MIN_BY"), - exp.Max: max_or_greatest, - exp.Min: min_or_least, - exp.Pow: lambda self, e: self.binary(e, "**"), - exp.Rand: lambda self, e: self.func("RANDOM", e.args.get("lower"), e.args.get("upper")), - exp.Select: transforms.preprocess( - [transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins] - ), - exp.StrPosition: lambda self, e: ( - strposition_sql( - self, e, func_name="INSTR", supports_position=True, supports_occurrence=True - ) - ), - exp.StrToDate: lambda self, - e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})", - exp.ToChar: lambda self, e: self.function_fallback_sql(e), - exp.ToNumber: to_number_with_nls_param, - exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}", - exp.DateAdd: _date_add_sql("+"), - exp.DateSub: _date_add_sql("-"), - exp.Quarter: lambda self, e: self.sql(exp.Extract(this="QUARTER", expression=e.this)), - } - - def currenttimestamp_sql(self, expression: exp.CurrentTimestamp) -> str: - prefix, suffix = ("(", ")") if expression.this else ("", "") - return self.func("CURRENT_TIMESTAMP", expression.this, prefix=prefix, suffix=suffix) - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - if expression.to.this == exp.DataType.Type.UNKNOWN and expression.args.get("format"): - # We don't actually want to print the unknown type in CAST( AS FORMAT ) - expression.to.pop() - - return super().cast_sql(expression, safe_prefix=safe_prefix) - - def trycast_sql(self, expression: exp.TryCast) -> str: - return self.cast_sql(expression, safe_prefix="TRY") - - def tablesample_sql( - self, - expression: exp.TableSample, - 
tablesample_keyword: t.Optional[str] = None, - ) -> str: - return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}" - - def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str: - return f"PARTITION BY {self.sql(expression, 'this')}" - - # FROM before SET in Teradata UPDATE syntax - # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause - def update_sql(self, expression: exp.Update) -> str: - this = self.sql(expression, "this") - from_sql = self.sql(expression, "from") - set_sql = self.expressions(expression, flat=True) - where_sql = self.sql(expression, "where") - sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}" - return self.prepend_ctes(expression, sql) - - def mod_sql(self, expression: exp.Mod) -> str: - return self.binary(expression, "MOD") - - def datatype_sql(self, expression: exp.DataType) -> str: - type_sql = super().datatype_sql(expression) - prefix_sql = expression.args.get("prefix") - return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql - - def rangen_sql(self, expression: exp.RangeN) -> str: - this = self.sql(expression, "this") - expressions_sql = self.expressions(expression) - each_sql = self.sql(expression, "each") - each_sql = f" EACH {each_sql}" if each_sql else "" - - return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})" - - def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: - kind = self.sql(expression, "kind").upper() - if kind == "TABLE" and locations.get(exp.Properties.Location.POST_NAME): - this_name = self.sql(expression.this, "this") - this_properties = self.properties( - exp.Properties(expressions=locations[exp.Properties.Location.POST_NAME]), - wrapped=False, - prefix=",", - ) - this_schema = self.schema_columns_sql(expression.this) - return f"{this_name}{this_properties}{self.sep()}{this_schema}" - - return super().createable_sql(expression, locations) - - def extract_sql(self, expression: exp.Extract) -> str: - this = self.sql(expression, "this") - if this.upper() != "QUARTER": - return super().extract_sql(expression) - - to_char = exp.func("to_char", expression.expression, exp.Literal.string("Q")) - return self.sql(exp.cast(to_char, exp.DataType.Type.INT)) - - def interval_sql(self, expression: exp.Interval) -> str: - multiplier = 0 - unit = expression.text("unit") - - if unit.startswith("WEEK"): - multiplier = 7 - elif unit.startswith("QUARTER"): - multiplier = 90 - - if multiplier: - return f"({multiplier} * {super().interval_sql(exp.Interval(this=expression.this, unit=exp.var('DAY')))})" - - return super().interval_sql(expression) diff --git a/altimate_packages/sqlglot/dialects/trino.py b/altimate_packages/sqlglot/dialects/trino.py deleted file mode 100644 index 42cfb04f4..000000000 --- a/altimate_packages/sqlglot/dialects/trino.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -from sqlglot import exp, parser, transforms -from sqlglot.dialects.dialect import ( - merge_without_target_sql, - trim_sql, - timestrtotime_sql, - groupconcat_sql, -) -from sqlglot.dialects.presto import amend_exploded_column_table, Presto -from sqlglot.tokens import TokenType -import typing as t - - -class Trino(Presto): - SUPPORTS_USER_DEFINED_TYPES = False - LOG_BASE_FIRST = True - - class Parser(Presto.Parser): - FUNCTION_PARSERS = { - **Presto.Parser.FUNCTION_PARSERS, - "TRIM": lambda self: self._parse_trim(), - "JSON_QUERY": lambda 
self: self._parse_json_query(), - "LISTAGG": lambda self: self._parse_string_agg(), - } - - JSON_QUERY_OPTIONS: parser.OPTIONS_TYPE = { - **dict.fromkeys( - ("WITH", "WITHOUT"), - ( - ("WRAPPER"), - ("ARRAY", "WRAPPER"), - ("CONDITIONAL", "WRAPPER"), - ("CONDITIONAL", "ARRAY", "WRAPPED"), - ("UNCONDITIONAL", "WRAPPER"), - ("UNCONDITIONAL", "ARRAY", "WRAPPER"), - ), - ), - } - - def _parse_json_query_quote(self) -> t.Optional[exp.JSONExtractQuote]: - if not ( - self._match_text_seq("KEEP", "QUOTES") or self._match_text_seq("OMIT", "QUOTES") - ): - return None - - return self.expression( - exp.JSONExtractQuote, - option=self._tokens[self._index - 2].text.upper(), - scalar=self._match_text_seq("ON", "SCALAR", "STRING"), - ) - - def _parse_json_query(self) -> exp.JSONExtract: - return self.expression( - exp.JSONExtract, - this=self._parse_bitwise(), - expression=self._match(TokenType.COMMA) and self._parse_bitwise(), - option=self._parse_var_from_options(self.JSON_QUERY_OPTIONS, raise_unmatched=False), - json_query=True, - quote=self._parse_json_query_quote(), - on_condition=self._parse_on_condition(), - ) - - class Generator(Presto.Generator): - PROPERTIES_LOCATION = { - **Presto.Generator.PROPERTIES_LOCATION, - exp.LocationProperty: exp.Properties.Location.POST_WITH, - } - - TRANSFORMS = { - **Presto.Generator.TRANSFORMS, - exp.ArraySum: lambda self, - e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", - exp.ArrayUniqueAgg: lambda self, e: f"ARRAY_AGG(DISTINCT {self.sql(e, 'this')})", - exp.GroupConcat: lambda self, e: groupconcat_sql(self, e, on_overflow=True), - exp.LocationProperty: lambda self, e: self.property_sql(e), - exp.Merge: merge_without_target_sql, - exp.Select: transforms.preprocess( - [ - transforms.eliminate_qualify, - transforms.eliminate_distinct_on, - transforms.explode_projection_to_unnest(1), - transforms.eliminate_semi_and_anti_joins, - amend_exploded_column_table, - ] - ), - exp.TimeStrToTime: lambda self, e: timestrtotime_sql(self, e, include_precision=True), - exp.Trim: trim_sql, - } - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - } - - def jsonextract_sql(self, expression: exp.JSONExtract) -> str: - if not expression.args.get("json_query"): - return super().jsonextract_sql(expression) - - json_path = self.sql(expression, "expression") - option = self.sql(expression, "option") - option = f" {option}" if option else "" - - quote = self.sql(expression, "quote") - quote = f" {quote}" if quote else "" - - on_condition = self.sql(expression, "on_condition") - on_condition = f" {on_condition}" if on_condition else "" - - return self.func( - "JSON_QUERY", - expression.this, - json_path + option + quote + on_condition, - ) diff --git a/altimate_packages/sqlglot/dialects/tsql.py b/altimate_packages/sqlglot/dialects/tsql.py deleted file mode 100644 index 9e1dd97b2..000000000 --- a/altimate_packages/sqlglot/dialects/tsql.py +++ /dev/null @@ -1,1403 +0,0 @@ -from __future__ import annotations - -import datetime -import re -import typing as t -from functools import partial, reduce - -from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import ( - Dialect, - NormalizationStrategy, - any_value_to_max_sql, - build_date_delta, - date_delta_sql, - datestrtodate_sql, - generatedasidentitycolumnconstraint_sql, - max_or_greatest, - min_or_least, - rename_func, - strposition_sql, - timestrtotime_sql, - trim_sql, -) -from sqlglot.helper import seq_get -from sqlglot.parser 
import build_coalesce -from sqlglot.time import format_time -from sqlglot.tokens import TokenType - -if t.TYPE_CHECKING: - from sqlglot._typing import E - -FULL_FORMAT_TIME_MAPPING = { - "weekday": "%A", - "dw": "%A", - "w": "%A", - "month": "%B", - "mm": "%B", - "m": "%B", -} - -DATE_DELTA_INTERVAL = { - "year": "year", - "yyyy": "year", - "yy": "year", - "quarter": "quarter", - "qq": "quarter", - "q": "quarter", - "month": "month", - "mm": "month", - "m": "month", - "week": "week", - "ww": "week", - "wk": "week", - "day": "day", - "dd": "day", - "d": "day", -} - - -DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})") - -# N = Numeric, C=Currency -TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} - -DEFAULT_START_DATE = datetime.date(1900, 1, 1) - -BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias} - -# Unsupported options: -# - OPTIMIZE FOR ( @variable_name { UNKNOWN | = } [ , ...n ] ) -# - TABLE HINT -OPTIONS: parser.OPTIONS_TYPE = { - **dict.fromkeys( - ( - "DISABLE_OPTIMIZED_PLAN_FORCING", - "FAST", - "IGNORE_NONCLUSTERED_COLUMNSTORE_INDEX", - "LABEL", - "MAXDOP", - "MAXRECURSION", - "MAX_GRANT_PERCENT", - "MIN_GRANT_PERCENT", - "NO_PERFORMANCE_SPOOL", - "QUERYTRACEON", - "RECOMPILE", - ), - tuple(), - ), - "CONCAT": ("UNION",), - "DISABLE": ("EXTERNALPUSHDOWN", "SCALEOUTEXECUTION"), - "EXPAND": ("VIEWS",), - "FORCE": ("EXTERNALPUSHDOWN", "ORDER", "SCALEOUTEXECUTION"), - "HASH": ("GROUP", "JOIN", "UNION"), - "KEEP": ("PLAN",), - "KEEPFIXED": ("PLAN",), - "LOOP": ("JOIN",), - "MERGE": ("JOIN", "UNION"), - "OPTIMIZE": (("FOR", "UNKNOWN"),), - "ORDER": ("GROUP",), - "PARAMETERIZATION": ("FORCED", "SIMPLE"), - "ROBUST": ("PLAN",), - "USE": ("PLAN",), -} - - -XML_OPTIONS: parser.OPTIONS_TYPE = { - **dict.fromkeys( - ( - "AUTO", - "EXPLICIT", - "TYPE", - ), - tuple(), - ), - "ELEMENTS": ( - "XSINIL", - "ABSENT", - ), - "BINARY": ("BASE64",), -} - - -OPTIONS_THAT_REQUIRE_EQUAL = ("MAX_GRANT_PERCENT", "MIN_GRANT_PERCENT", "LABEL") - - -def _build_formatted_time( - exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None -) -> t.Callable[[t.List], E]: - def _builder(args: t.List) -> E: - fmt = seq_get(args, 0) - if isinstance(fmt, exp.Expression): - fmt = exp.Literal.string( - format_time( - fmt.name.lower(), - ( - {**TSQL.TIME_MAPPING, **FULL_FORMAT_TIME_MAPPING} - if full_format_mapping - else TSQL.TIME_MAPPING - ), - ) - ) - - this = seq_get(args, 1) - if isinstance(this, exp.Expression): - this = exp.cast(this, exp.DataType.Type.DATETIME2) - - return exp_class(this=this, format=fmt) - - return _builder - - -def _build_format(args: t.List) -> exp.NumberToStr | exp.TimeToStr: - this = seq_get(args, 0) - fmt = seq_get(args, 1) - culture = seq_get(args, 2) - - number_fmt = fmt and (fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.name)) - - if number_fmt: - return exp.NumberToStr(this=this, format=fmt, culture=culture) - - if fmt: - fmt = exp.Literal.string( - format_time(fmt.name, TSQL.FORMAT_TIME_MAPPING) - if len(fmt.name) == 1 - else format_time(fmt.name, TSQL.TIME_MAPPING) - ) - - return exp.TimeToStr(this=this, format=fmt, culture=culture) - - -def _build_eomonth(args: t.List) -> exp.LastDay: - date = exp.TsOrDsToDate(this=seq_get(args, 0)) - month_lag = seq_get(args, 1) - - if month_lag is None: - this: exp.Expression = date - else: - unit = DATE_DELTA_INTERVAL.get("month") - this = exp.DateAdd(this=date, expression=month_lag, unit=unit and exp.var(unit)) - - return exp.LastDay(this=this) - - -def 
_build_hashbytes(args: t.List) -> exp.Expression: - kind, data = args - kind = kind.name.upper() if kind.is_string else "" - - if kind == "MD5": - args.pop(0) - return exp.MD5(this=data) - if kind in ("SHA", "SHA1"): - args.pop(0) - return exp.SHA(this=data) - if kind == "SHA2_256": - return exp.SHA2(this=data, length=exp.Literal.number(256)) - if kind == "SHA2_512": - return exp.SHA2(this=data, length=exp.Literal.number(512)) - - return exp.func("HASHBYTES", *args) - - -DATEPART_ONLY_FORMATS = {"DW", "WK", "HOUR", "QUARTER"} - - -def _format_sql(self: TSQL.Generator, expression: exp.NumberToStr | exp.TimeToStr) -> str: - fmt = expression.args["format"] - - if not isinstance(expression, exp.NumberToStr): - if fmt.is_string: - mapped_fmt = format_time(fmt.name, TSQL.INVERSE_TIME_MAPPING) - - name = (mapped_fmt or "").upper() - if name in DATEPART_ONLY_FORMATS: - return self.func("DATEPART", name, expression.this) - - fmt_sql = self.sql(exp.Literal.string(mapped_fmt)) - else: - fmt_sql = self.format_time(expression) or self.sql(fmt) - else: - fmt_sql = self.sql(fmt) - - return self.func("FORMAT", expression.this, fmt_sql, expression.args.get("culture")) - - -def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: - this = expression.this - distinct = expression.find(exp.Distinct) - if distinct: - # exp.Distinct can appear below an exp.Order or an exp.GroupConcat expression - self.unsupported("T-SQL STRING_AGG doesn't support DISTINCT.") - this = distinct.pop().expressions[0] - - order = "" - if isinstance(expression.this, exp.Order): - if expression.this.this: - this = expression.this.this.pop() - # Order has a leading space - order = f" WITHIN GROUP ({self.sql(expression.this)[1:]})" - - separator = expression.args.get("separator") or exp.Literal.string(",") - return f"STRING_AGG({self.format_args(this, separator)}){order}" - - -def _build_date_delta( - exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None -) -> t.Callable[[t.List], E]: - def _builder(args: t.List) -> E: - unit = seq_get(args, 0) - if unit and unit_mapping: - unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) - - start_date = seq_get(args, 1) - if start_date and start_date.is_number: - # Numeric types are valid DATETIME values - if start_date.is_int: - adds = DEFAULT_START_DATE + datetime.timedelta(days=start_date.to_py()) - start_date = exp.Literal.string(adds.strftime("%F")) - else: - # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs. - # This is not a problem when generating T-SQL code, it is when transpiling to other dialects. 
- return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit) - - return exp_class( - this=exp.TimeStrToTime(this=seq_get(args, 2)), - expression=exp.TimeStrToTime(this=start_date), - unit=unit, - ) - - return _builder - - -def qualify_derived_table_outputs(expression: exp.Expression) -> exp.Expression: - """Ensures all (unnamed) output columns are aliased for CTEs and Subqueries.""" - alias = expression.args.get("alias") - - if ( - isinstance(expression, (exp.CTE, exp.Subquery)) - and isinstance(alias, exp.TableAlias) - and not alias.columns - ): - from sqlglot.optimizer.qualify_columns import qualify_outputs - - # We keep track of the unaliased column projection indexes instead of the expressions - # themselves, because the latter are going to be replaced by new nodes when the aliases - # are added and hence we won't be able to reach these newly added Alias parents - query = expression.this - unaliased_column_indexes = ( - i for i, c in enumerate(query.selects) if isinstance(c, exp.Column) and not c.alias - ) - - qualify_outputs(query) - - # Preserve the quoting information of columns for newly added Alias nodes - query_selects = query.selects - for select_index in unaliased_column_indexes: - alias = query_selects[select_index] - column = alias.this - if isinstance(column.this, exp.Identifier): - alias.args["alias"].set("quoted", column.this.quoted) - - return expression - - -# https://learn.microsoft.com/en-us/sql/t-sql/functions/datetimefromparts-transact-sql?view=sql-server-ver16#syntax -def _build_datetimefromparts(args: t.List) -> exp.TimestampFromParts: - return exp.TimestampFromParts( - year=seq_get(args, 0), - month=seq_get(args, 1), - day=seq_get(args, 2), - hour=seq_get(args, 3), - min=seq_get(args, 4), - sec=seq_get(args, 5), - milli=seq_get(args, 6), - ) - - -# https://learn.microsoft.com/en-us/sql/t-sql/functions/timefromparts-transact-sql?view=sql-server-ver16#syntax -def _build_timefromparts(args: t.List) -> exp.TimeFromParts: - return exp.TimeFromParts( - hour=seq_get(args, 0), - min=seq_get(args, 1), - sec=seq_get(args, 2), - fractions=seq_get(args, 3), - precision=seq_get(args, 4), - ) - - -def _build_with_arg_as_text( - klass: t.Type[exp.Expression], -) -> t.Callable[[t.List[exp.Expression]], exp.Expression]: - def _parse(args: t.List[exp.Expression]) -> exp.Expression: - this = seq_get(args, 0) - - if this and not this.is_string: - this = exp.cast(this, exp.DataType.Type.TEXT) - - expression = seq_get(args, 1) - kwargs = {"this": this} - - if expression: - kwargs["expression"] = expression - - return klass(**kwargs) - - return _parse - - -# https://learn.microsoft.com/en-us/sql/t-sql/functions/parsename-transact-sql?view=sql-server-ver16 -def _build_parsename(args: t.List) -> exp.SplitPart | exp.Anonymous: - # PARSENAME(...) will be stored into exp.SplitPart if: - # - All args are literals - # - The part index (2nd arg) is <= 4 (max valid value, otherwise TSQL returns NULL) - if len(args) == 2 and all(isinstance(arg, exp.Literal) for arg in args): - this = args[0] - part_index = args[1] - split_count = len(this.name.split(".")) - if split_count <= 4: - return exp.SplitPart( - this=this, - delimiter=exp.Literal.string("."), - part_index=exp.Literal.number(split_count + 1 - part_index.to_py()), - ) - - return exp.Anonymous(this="PARSENAME", expressions=args) - - -def _build_json_query(args: t.List, dialect: Dialect) -> exp.JSONExtract: - if len(args) == 1: - # The default value for path is '$'. 
As a result, if you don't provide a - # value for path, JSON_QUERY returns the input expression. - args.append(exp.Literal.string("$")) - - return parser.build_extract_json_with_path(exp.JSONExtract)(args, dialect) - - -def _json_extract_sql( - self: TSQL.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar -) -> str: - json_query = self.func("JSON_QUERY", expression.this, expression.expression) - json_value = self.func("JSON_VALUE", expression.this, expression.expression) - return self.func("ISNULL", json_query, json_value) - - -def _timestrtotime_sql(self: TSQL.Generator, expression: exp.TimeStrToTime): - sql = timestrtotime_sql(self, expression) - if expression.args.get("zone"): - # If there is a timezone, produce an expression like: - # CAST('2020-01-01 12:13:14-08:00' AS DATETIMEOFFSET) AT TIME ZONE 'UTC' - # If you dont have AT TIME ZONE 'UTC', wrapping that expression in another cast back to DATETIME2 just drops the timezone information - return self.sql(exp.AtTimeZone(this=sql, zone=exp.Literal.string("UTC"))) - return sql - - -def _build_datetrunc(args: t.List) -> exp.TimestampTrunc: - unit = seq_get(args, 0) - this = seq_get(args, 1) - - if this and this.is_string: - this = exp.cast(this, exp.DataType.Type.DATETIME2) - - return exp.TimestampTrunc(this=this, unit=unit) - - -class TSQL(Dialect): - SUPPORTS_SEMI_ANTI_JOIN = False - LOG_BASE_FIRST = False - TYPED_DIVISION = True - CONCAT_COALESCE = True - NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE - ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN = False - - TIME_FORMAT = "'yyyy-mm-dd hh:mm:ss'" - - TIME_MAPPING = { - "year": "%Y", - "dayofyear": "%j", - "day": "%d", - "dy": "%d", - "y": "%Y", - "week": "%W", - "ww": "%W", - "wk": "%W", - "hour": "%h", - "hh": "%I", - "minute": "%M", - "mi": "%M", - "n": "%M", - "second": "%S", - "ss": "%S", - "s": "%-S", - "millisecond": "%f", - "ms": "%f", - "weekday": "%w", - "dw": "%w", - "month": "%m", - "mm": "%M", - "m": "%-M", - "Y": "%Y", - "YYYY": "%Y", - "YY": "%y", - "MMMM": "%B", - "MMM": "%b", - "MM": "%m", - "M": "%-m", - "dddd": "%A", - "dd": "%d", - "d": "%-d", - "HH": "%H", - "H": "%-H", - "h": "%-I", - "ffffff": "%f", - "yyyy": "%Y", - "yy": "%y", - } - - CONVERT_FORMAT_MAPPING = { - "0": "%b %d %Y %-I:%M%p", - "1": "%m/%d/%y", - "2": "%y.%m.%d", - "3": "%d/%m/%y", - "4": "%d.%m.%y", - "5": "%d-%m-%y", - "6": "%d %b %y", - "7": "%b %d, %y", - "8": "%H:%M:%S", - "9": "%b %d %Y %-I:%M:%S:%f%p", - "10": "mm-dd-yy", - "11": "yy/mm/dd", - "12": "yymmdd", - "13": "%d %b %Y %H:%M:ss:%f", - "14": "%H:%M:%S:%f", - "20": "%Y-%m-%d %H:%M:%S", - "21": "%Y-%m-%d %H:%M:%S.%f", - "22": "%m/%d/%y %-I:%M:%S %p", - "23": "%Y-%m-%d", - "24": "%H:%M:%S", - "25": "%Y-%m-%d %H:%M:%S.%f", - "100": "%b %d %Y %-I:%M%p", - "101": "%m/%d/%Y", - "102": "%Y.%m.%d", - "103": "%d/%m/%Y", - "104": "%d.%m.%Y", - "105": "%d-%m-%Y", - "106": "%d %b %Y", - "107": "%b %d, %Y", - "108": "%H:%M:%S", - "109": "%b %d %Y %-I:%M:%S:%f%p", - "110": "%m-%d-%Y", - "111": "%Y/%m/%d", - "112": "%Y%m%d", - "113": "%d %b %Y %H:%M:%S:%f", - "114": "%H:%M:%S:%f", - "120": "%Y-%m-%d %H:%M:%S", - "121": "%Y-%m-%d %H:%M:%S.%f", - "126": "%Y-%m-%dT%H:%M:%S.%f", - } - - FORMAT_TIME_MAPPING = { - "y": "%B %Y", - "d": "%m/%d/%Y", - "H": "%-H", - "h": "%-I", - "s": "%Y-%m-%d %H:%M:%S", - "D": "%A,%B,%Y", - "f": "%A,%B,%Y %-I:%M %p", - "F": "%A,%B,%Y %-I:%M:%S %p", - "g": "%m/%d/%Y %-I:%M %p", - "G": "%m/%d/%Y %-I:%M:%S %p", - "M": "%B %-d", - "m": "%B %-d", - "O": "%Y-%m-%dT%H:%M:%S", - "u": "%Y-%M-%D %H:%M:%S%z", - 
"U": "%A, %B %D, %Y %H:%M:%S%z", - "T": "%-I:%M:%S %p", - "t": "%-I:%M", - "Y": "%a %Y", - } - - class Tokenizer(tokens.Tokenizer): - IDENTIFIERS = [("[", "]"), '"'] - QUOTES = ["'", '"'] - HEX_STRINGS = [("0x", ""), ("0X", "")] - VAR_SINGLE_TOKENS = {"@", "$", "#"} - - KEYWORDS = { - **tokens.Tokenizer.KEYWORDS, - "CLUSTERED INDEX": TokenType.INDEX, - "DATETIME2": TokenType.DATETIME2, - "DATETIMEOFFSET": TokenType.TIMESTAMPTZ, - "DECLARE": TokenType.DECLARE, - "EXEC": TokenType.COMMAND, - "FOR SYSTEM_TIME": TokenType.TIMESTAMP_SNAPSHOT, - "GO": TokenType.COMMAND, - "IMAGE": TokenType.IMAGE, - "MONEY": TokenType.MONEY, - "NONCLUSTERED INDEX": TokenType.INDEX, - "NTEXT": TokenType.TEXT, - "OPTION": TokenType.OPTION, - "OUTPUT": TokenType.RETURNING, - "PRINT": TokenType.COMMAND, - "PROC": TokenType.PROCEDURE, - "REAL": TokenType.FLOAT, - "ROWVERSION": TokenType.ROWVERSION, - "SMALLDATETIME": TokenType.SMALLDATETIME, - "SMALLMONEY": TokenType.SMALLMONEY, - "SQL_VARIANT": TokenType.VARIANT, - "SYSTEM_USER": TokenType.CURRENT_USER, - "TOP": TokenType.TOP, - "TIMESTAMP": TokenType.ROWVERSION, - "TINYINT": TokenType.UTINYINT, - "UNIQUEIDENTIFIER": TokenType.UUID, - "UPDATE STATISTICS": TokenType.COMMAND, - "XML": TokenType.XML, - } - KEYWORDS.pop("/*+") - - COMMANDS = {*tokens.Tokenizer.COMMANDS, TokenType.END} - - class Parser(parser.Parser): - SET_REQUIRES_ASSIGNMENT_DELIMITER = False - LOG_DEFAULTS_TO_LN = True - STRING_ALIASES = True - NO_PAREN_IF_COMMANDS = False - - QUERY_MODIFIER_PARSERS = { - **parser.Parser.QUERY_MODIFIER_PARSERS, - TokenType.OPTION: lambda self: ("options", self._parse_options()), - TokenType.FOR: lambda self: ("for", self._parse_for()), - } - - # T-SQL does not allow BEGIN to be used as an identifier - ID_VAR_TOKENS = parser.Parser.ID_VAR_TOKENS - {TokenType.BEGIN} - ALIAS_TOKENS = parser.Parser.ALIAS_TOKENS - {TokenType.BEGIN} - TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {TokenType.BEGIN} - COMMENT_TABLE_ALIAS_TOKENS = parser.Parser.COMMENT_TABLE_ALIAS_TOKENS - {TokenType.BEGIN} - UPDATE_ALIAS_TOKENS = parser.Parser.UPDATE_ALIAS_TOKENS - {TokenType.BEGIN} - - FUNCTIONS = { - **parser.Parser.FUNCTIONS, - "CHARINDEX": lambda args: exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - position=seq_get(args, 2), - ), - "COUNT": lambda args: exp.Count( - this=seq_get(args, 0), expressions=args[1:], big_int=False - ), - "COUNT_BIG": lambda args: exp.Count( - this=seq_get(args, 0), expressions=args[1:], big_int=True - ), - "DATEADD": build_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), - "DATEDIFF": _build_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), - "DATENAME": _build_formatted_time(exp.TimeToStr, full_format_mapping=True), - "DATEPART": _build_formatted_time(exp.TimeToStr), - "DATETIMEFROMPARTS": _build_datetimefromparts, - "EOMONTH": _build_eomonth, - "FORMAT": _build_format, - "GETDATE": exp.CurrentTimestamp.from_arg_list, - "HASHBYTES": _build_hashbytes, - "ISNULL": lambda args: build_coalesce(args=args, is_null=True), - "JSON_QUERY": _build_json_query, - "JSON_VALUE": parser.build_extract_json_with_path(exp.JSONExtractScalar), - "LEN": _build_with_arg_as_text(exp.Length), - "LEFT": _build_with_arg_as_text(exp.Left), - "NEWID": exp.Uuid.from_arg_list, - "RIGHT": _build_with_arg_as_text(exp.Right), - "PARSENAME": _build_parsename, - "REPLICATE": exp.Repeat.from_arg_list, - "SCHEMA_NAME": exp.CurrentSchema.from_arg_list, - "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), - 
"SYSDATETIME": exp.CurrentTimestamp.from_arg_list, - "SUSER_NAME": exp.CurrentUser.from_arg_list, - "SUSER_SNAME": exp.CurrentUser.from_arg_list, - "SYSTEM_USER": exp.CurrentUser.from_arg_list, - "TIMEFROMPARTS": _build_timefromparts, - "DATETRUNC": _build_datetrunc, - } - - JOIN_HINTS = {"LOOP", "HASH", "MERGE", "REMOTE"} - - PROCEDURE_OPTIONS = dict.fromkeys( - ("ENCRYPTION", "RECOMPILE", "SCHEMABINDING", "NATIVE_COMPILATION", "EXECUTE"), tuple() - ) - - COLUMN_DEFINITION_MODES = {"OUT", "OUTPUT", "READ_ONLY"} - - RETURNS_TABLE_TOKENS = parser.Parser.ID_VAR_TOKENS - { - TokenType.TABLE, - *parser.Parser.TYPE_TOKENS, - } - - STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, - TokenType.DECLARE: lambda self: self._parse_declare(), - } - - RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, - TokenType.DCOLON: lambda self, this: self.expression( - exp.ScopeResolution, - this=this, - expression=self._parse_function() or self._parse_var(any_token=True), - ), - } - - NO_PAREN_FUNCTION_PARSERS = { - **parser.Parser.NO_PAREN_FUNCTION_PARSERS, - "NEXT": lambda self: self._parse_next_value_for(), - } - - # The DCOLON (::) operator serves as a scope resolution (exp.ScopeResolution) operator in T-SQL - COLUMN_OPERATORS = { - **parser.Parser.COLUMN_OPERATORS, - TokenType.DCOLON: lambda self, this, to: self.expression(exp.Cast, this=this, to=to) - if isinstance(to, exp.DataType) and to.this != exp.DataType.Type.USERDEFINED - else self.expression(exp.ScopeResolution, this=this, expression=to), - } - - def _parse_alter_table_set(self) -> exp.AlterSet: - return self._parse_wrapped(super()._parse_alter_table_set) - - def _parse_wrapped_select(self, table: bool = False) -> t.Optional[exp.Expression]: - if self._match(TokenType.MERGE): - comments = self._prev_comments - merge = self._parse_merge() - merge.add_comments(comments, prepend=True) - return merge - - return super()._parse_wrapped_select(table=table) - - def _parse_dcolon(self) -> t.Optional[exp.Expression]: - # We want to use _parse_types() if the first token after :: is a known type, - # otherwise we could parse something like x::varchar(max) into a function - if self._match_set(self.TYPE_TOKENS, advance=False): - return self._parse_types() - - return self._parse_function() or self._parse_types() - - def _parse_options(self) -> t.Optional[t.List[exp.Expression]]: - if not self._match(TokenType.OPTION): - return None - - def _parse_option() -> t.Optional[exp.Expression]: - option = self._parse_var_from_options(OPTIONS) - if not option: - return None - - self._match(TokenType.EQ) - return self.expression( - exp.QueryOption, this=option, expression=self._parse_primary_or_var() - ) - - return self._parse_wrapped_csv(_parse_option) - - def _parse_xml_key_value_option(self) -> exp.XMLKeyValueOption: - this = self._parse_primary_or_var() - if self._match(TokenType.L_PAREN, advance=False): - expression = self._parse_wrapped(self._parse_string) - else: - expression = None - - return exp.XMLKeyValueOption(this=this, expression=expression) - - def _parse_for(self) -> t.Optional[t.List[exp.Expression]]: - if not self._match_pair(TokenType.FOR, TokenType.XML): - return None - - def _parse_for_xml() -> t.Optional[exp.Expression]: - return self.expression( - exp.QueryOption, - this=self._parse_var_from_options(XML_OPTIONS, raise_unmatched=False) - or self._parse_xml_key_value_option(), - ) - - return self._parse_csv(_parse_for_xml) - - def _parse_projections(self) -> t.List[exp.Expression]: - """ - T-SQL supports the syntax alias = expression in the 
SELECT's projection list, - so we transform all parsed Selects to convert their EQ projections into Aliases. - - See: https://learn.microsoft.com/en-us/sql/t-sql/queries/select-clause-transact-sql?view=sql-server-ver16#syntax - """ - return [ - ( - exp.alias_(projection.expression, projection.this.this, copy=False) - if isinstance(projection, exp.EQ) and isinstance(projection.this, exp.Column) - else projection - ) - for projection in super()._parse_projections() - ] - - def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: - """Applies to SQL Server and Azure SQL Database - COMMIT [ { TRAN | TRANSACTION } - [ transaction_name | @tran_name_variable ] ] - [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] - - ROLLBACK { TRAN | TRANSACTION } - [ transaction_name | @tran_name_variable - | savepoint_name | @savepoint_variable ] - """ - rollback = self._prev.token_type == TokenType.ROLLBACK - - self._match_texts(("TRAN", "TRANSACTION")) - this = self._parse_id_var() - - if rollback: - return self.expression(exp.Rollback, this=this) - - durability = None - if self._match_pair(TokenType.WITH, TokenType.L_PAREN): - self._match_text_seq("DELAYED_DURABILITY") - self._match(TokenType.EQ) - - if self._match_text_seq("OFF"): - durability = False - else: - self._match(TokenType.ON) - durability = True - - self._match_r_paren() - - return self.expression(exp.Commit, this=this, durability=durability) - - def _parse_transaction(self) -> exp.Transaction | exp.Command: - """Applies to SQL Server and Azure SQL Database - BEGIN { TRAN | TRANSACTION } - [ { transaction_name | @tran_name_variable } - [ WITH MARK [ 'description' ] ] - ] - """ - if self._match_texts(("TRAN", "TRANSACTION")): - transaction = self.expression(exp.Transaction, this=self._parse_id_var()) - if self._match_text_seq("WITH", "MARK"): - transaction.set("mark", self._parse_string()) - - return transaction - - return self._parse_as_command(self._prev) - - def _parse_returns(self) -> exp.ReturnsProperty: - table = self._parse_id_var(any_token=False, tokens=self.RETURNS_TABLE_TOKENS) - returns = super()._parse_returns() - returns.set("table", table) - return returns - - def _parse_convert( - self, strict: bool, safe: t.Optional[bool] = None - ) -> t.Optional[exp.Expression]: - this = self._parse_types() - self._match(TokenType.COMMA) - args = [this, *self._parse_csv(self._parse_assignment)] - convert = exp.Convert.from_arg_list(args) - convert.set("safe", safe) - convert.set("strict", strict) - return convert - - def _parse_column_def( - self, this: t.Optional[exp.Expression], computed_column: bool = True - ) -> t.Optional[exp.Expression]: - this = super()._parse_column_def(this=this, computed_column=computed_column) - if not this: - return None - if self._match(TokenType.EQ): - this.set("default", self._parse_disjunction()) - if self._match_texts(self.COLUMN_DEFINITION_MODES): - this.set("output", self._prev.text) - return this - - def _parse_user_defined_function( - self, kind: t.Optional[TokenType] = None - ) -> t.Optional[exp.Expression]: - this = super()._parse_user_defined_function(kind=kind) - - if ( - kind == TokenType.FUNCTION - or isinstance(this, exp.UserDefinedFunction) - or self._match(TokenType.ALIAS, advance=False) - ): - return this - - if not self._match(TokenType.WITH, advance=False): - expressions = self._parse_csv(self._parse_function_parameter) - else: - expressions = None - - return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) - - def _parse_into(self) -> t.Optional[exp.Into]: - into 
= super()._parse_into() - - table = isinstance(into, exp.Into) and into.find(exp.Table) - if isinstance(table, exp.Table): - table_identifier = table.this - if table_identifier.args.get("temporary"): - # Promote the temporary property from the Identifier to the Into expression - t.cast(exp.Into, into).set("temporary", True) - - return into - - def _parse_id_var( - self, - any_token: bool = True, - tokens: t.Optional[t.Collection[TokenType]] = None, - ) -> t.Optional[exp.Expression]: - is_temporary = self._match(TokenType.HASH) - is_global = is_temporary and self._match(TokenType.HASH) - - this = super()._parse_id_var(any_token=any_token, tokens=tokens) - if this: - if is_global: - this.set("global", True) - elif is_temporary: - this.set("temporary", True) - - return this - - def _parse_create(self) -> exp.Create | exp.Command: - create = super()._parse_create() - - if isinstance(create, exp.Create): - table = create.this.this if isinstance(create.this, exp.Schema) else create.this - if isinstance(table, exp.Table) and table.this and table.this.args.get("temporary"): - if not create.args.get("properties"): - create.set("properties", exp.Properties(expressions=[])) - - create.args["properties"].append("expressions", exp.TemporaryProperty()) - - return create - - def _parse_if(self) -> t.Optional[exp.Expression]: - index = self._index - - if self._match_text_seq("OBJECT_ID"): - self._parse_wrapped_csv(self._parse_string) - if self._match_text_seq("IS", "NOT", "NULL") and self._match(TokenType.DROP): - return self._parse_drop(exists=True) - self._retreat(index) - - return super()._parse_if() - - def _parse_unique(self) -> exp.UniqueColumnConstraint: - if self._match_texts(("CLUSTERED", "NONCLUSTERED")): - this = self.CONSTRAINT_PARSERS[self._prev.text.upper()](self) - else: - this = self._parse_schema(self._parse_id_var(any_token=False)) - - return self.expression(exp.UniqueColumnConstraint, this=this) - - def _parse_partition(self) -> t.Optional[exp.Partition]: - if not self._match_text_seq("WITH", "(", "PARTITIONS"): - return None - - def parse_range(): - low = self._parse_bitwise() - high = self._parse_bitwise() if self._match_text_seq("TO") else None - - return ( - self.expression(exp.PartitionRange, this=low, expression=high) if high else low - ) - - partition = self.expression( - exp.Partition, expressions=self._parse_wrapped_csv(parse_range) - ) - - self._match_r_paren() - - return partition - - def _parse_declare(self) -> exp.Declare | exp.Command: - index = self._index - expressions = self._try_parse(partial(self._parse_csv, self._parse_declareitem)) - - if not expressions or self._curr: - self._retreat(index) - return self._parse_as_command(self._prev) - - return self.expression(exp.Declare, expressions=expressions) - - def _parse_declareitem(self) -> t.Optional[exp.DeclareItem]: - var = self._parse_id_var() - if not var: - return None - - value = None - self._match(TokenType.ALIAS) - if self._match(TokenType.TABLE): - data_type = self._parse_schema() - else: - data_type = self._parse_types() - if self._match(TokenType.EQ): - value = self._parse_bitwise() - - return self.expression(exp.DeclareItem, this=var, kind=data_type, default=value) - - def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]: - expression = super()._parse_alter_table_alter() - - if expression is not None: - collation = expression.args.get("collate") - if isinstance(collation, exp.Column) and isinstance(collation.this, exp.Identifier): - identifier = collation.this - collation.set("this", 
exp.Var(this=identifier.name)) - - return expression - - class Generator(generator.Generator): - LIMIT_IS_TOP = True - QUERY_HINTS = False - RETURNING_END = False - NVL2_SUPPORTED = False - ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = False - LIMIT_FETCH = "FETCH" - COMPUTED_COLUMN_WITH_TYPE = False - CTE_RECURSIVE_KEYWORD_REQUIRED = False - ENSURE_BOOLS = True - NULL_ORDERING_SUPPORTED = None - SUPPORTS_SINGLE_ARG_CONCAT = False - TABLESAMPLE_SEED_KEYWORD = "REPEATABLE" - SUPPORTS_SELECT_INTO = True - JSON_PATH_BRACKETED_KEY_SUPPORTED = False - SUPPORTS_TO_NUMBER = False - SET_OP_MODIFIERS = False - COPY_PARAMS_EQ_REQUIRED = True - PARSE_JSON_NAME = None - EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = False - ALTER_SET_WRAPPED = True - ALTER_SET_TYPE = "" - - EXPRESSIONS_WITHOUT_NESTED_CTES = { - exp.Create, - exp.Delete, - exp.Insert, - exp.Intersect, - exp.Except, - exp.Merge, - exp.Select, - exp.Subquery, - exp.Union, - exp.Update, - } - - SUPPORTED_JSON_PATH_PARTS = { - exp.JSONPathKey, - exp.JSONPathRoot, - exp.JSONPathSubscript, - } - - TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, - exp.DataType.Type.BOOLEAN: "BIT", - exp.DataType.Type.DATETIME2: "DATETIME2", - exp.DataType.Type.DECIMAL: "NUMERIC", - exp.DataType.Type.DOUBLE: "FLOAT", - exp.DataType.Type.INT: "INTEGER", - exp.DataType.Type.ROWVERSION: "ROWVERSION", - exp.DataType.Type.TEXT: "VARCHAR(MAX)", - exp.DataType.Type.TIMESTAMP: "DATETIME2", - exp.DataType.Type.TIMESTAMPNTZ: "DATETIME2", - exp.DataType.Type.TIMESTAMPTZ: "DATETIMEOFFSET", - exp.DataType.Type.SMALLDATETIME: "SMALLDATETIME", - exp.DataType.Type.UTINYINT: "TINYINT", - exp.DataType.Type.VARIANT: "SQL_VARIANT", - exp.DataType.Type.UUID: "UNIQUEIDENTIFIER", - } - - TYPE_MAPPING.pop(exp.DataType.Type.NCHAR) - TYPE_MAPPING.pop(exp.DataType.Type.NVARCHAR) - - TRANSFORMS = { - **generator.Generator.TRANSFORMS, - exp.AnyValue: any_value_to_max_sql, - exp.ArrayToString: rename_func("STRING_AGG"), - exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY", - exp.Chr: rename_func("CHAR"), - exp.DateAdd: date_delta_sql("DATEADD"), - exp.DateDiff: date_delta_sql("DATEDIFF"), - exp.CTE: transforms.preprocess([qualify_derived_table_outputs]), - exp.CurrentDate: rename_func("GETDATE"), - exp.CurrentTimestamp: rename_func("GETDATE"), - exp.DateStrToDate: datestrtodate_sql, - exp.Extract: rename_func("DATEPART"), - exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql, - exp.GroupConcat: _string_agg_sql, - exp.If: rename_func("IIF"), - exp.JSONExtract: _json_extract_sql, - exp.JSONExtractScalar: _json_extract_sql, - exp.LastDay: lambda self, e: self.func("EOMONTH", e.this), - exp.Ln: rename_func("LOG"), - exp.Max: max_or_greatest, - exp.MD5: lambda self, e: self.func("HASHBYTES", exp.Literal.string("MD5"), e.this), - exp.Min: min_or_least, - exp.NumberToStr: _format_sql, - exp.Repeat: rename_func("REPLICATE"), - exp.CurrentSchema: rename_func("SCHEMA_NAME"), - exp.Select: transforms.preprocess( - [ - transforms.eliminate_distinct_on, - transforms.eliminate_semi_and_anti_joins, - transforms.eliminate_qualify, - transforms.unnest_generate_date_array_using_recursive_cte, - ] - ), - exp.Stddev: rename_func("STDEV"), - exp.StrPosition: lambda self, e: strposition_sql( - self, e, func_name="CHARINDEX", supports_position=True - ), - exp.Subquery: transforms.preprocess([qualify_derived_table_outputs]), - exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), - exp.SHA2: lambda self, e: self.func( - "HASHBYTES", 
exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this - ), - exp.TemporaryProperty: lambda self, e: "", - exp.TimeStrToTime: _timestrtotime_sql, - exp.TimeToStr: _format_sql, - exp.Trim: trim_sql, - exp.TsOrDsAdd: date_delta_sql("DATEADD", cast=True), - exp.TsOrDsDiff: date_delta_sql("DATEDIFF"), - exp.TimestampTrunc: lambda self, e: self.func("DATETRUNC", e.unit, e.this), - exp.Uuid: lambda *_: "NEWID()", - exp.DateFromParts: rename_func("DATEFROMPARTS"), - } - - TRANSFORMS.pop(exp.ReturnsProperty) - - PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, - exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, - } - - def scope_resolution(self, rhs: str, scope_name: str) -> str: - return f"{scope_name}::{rhs}" - - def select_sql(self, expression: exp.Select) -> str: - limit = expression.args.get("limit") - offset = expression.args.get("offset") - - if isinstance(limit, exp.Fetch) and not offset: - # Dialects like Oracle can FETCH directly from a row set but - # T-SQL requires an ORDER BY + OFFSET clause in order to FETCH - offset = exp.Offset(expression=exp.Literal.number(0)) - expression.set("offset", offset) - - if offset: - if not expression.args.get("order"): - # ORDER BY is required in order to use OFFSET in a query, so we use - # a noop order by, since we don't really care about the order. - # See: https://www.microsoftpressstore.com/articles/article.aspx?p=2314819 - expression.order_by(exp.select(exp.null()).subquery(), copy=False) - - if isinstance(limit, exp.Limit): - # TOP and OFFSET can't be combined, we need use FETCH instead of TOP - # we replace here because otherwise TOP would be generated in select_sql - limit.replace(exp.Fetch(direction="FIRST", count=limit.expression)) - - return super().select_sql(expression) - - def convert_sql(self, expression: exp.Convert) -> str: - name = "TRY_CONVERT" if expression.args.get("safe") else "CONVERT" - return self.func( - name, expression.this, expression.expression, expression.args.get("style") - ) - - def queryoption_sql(self, expression: exp.QueryOption) -> str: - option = self.sql(expression, "this") - value = self.sql(expression, "expression") - if value: - optional_equal_sign = "= " if option in OPTIONS_THAT_REQUIRE_EQUAL else "" - return f"{option} {optional_equal_sign}{value}" - return option - - def lateral_op(self, expression: exp.Lateral) -> str: - cross_apply = expression.args.get("cross_apply") - if cross_apply is True: - return "CROSS APPLY" - if cross_apply is False: - return "OUTER APPLY" - - # TODO: perhaps we can check if the parent is a Join and transpile it appropriately - self.unsupported("LATERAL clause is not supported.") - return "LATERAL" - - def splitpart_sql(self: TSQL.Generator, expression: exp.SplitPart) -> str: - this = expression.this - split_count = len(this.name.split(".")) - delimiter = expression.args.get("delimiter") - part_index = expression.args.get("part_index") - - if ( - not all(isinstance(arg, exp.Literal) for arg in (this, delimiter, part_index)) - or (delimiter and delimiter.name != ".") - or not part_index - or split_count > 4 - ): - self.unsupported( - "SPLIT_PART can be transpiled to PARSENAME only for '.' 
delimiter and literal values" - ) - return "" - - return self.func( - "PARSENAME", this, exp.Literal.number(split_count + 1 - part_index.to_py()) - ) - - def timefromparts_sql(self, expression: exp.TimeFromParts) -> str: - nano = expression.args.get("nano") - if nano is not None: - nano.pop() - self.unsupported("Specifying nanoseconds is not supported in TIMEFROMPARTS.") - - if expression.args.get("fractions") is None: - expression.set("fractions", exp.Literal.number(0)) - if expression.args.get("precision") is None: - expression.set("precision", exp.Literal.number(0)) - - return rename_func("TIMEFROMPARTS")(self, expression) - - def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str: - zone = expression.args.get("zone") - if zone is not None: - zone.pop() - self.unsupported("Time zone is not supported in DATETIMEFROMPARTS.") - - nano = expression.args.get("nano") - if nano is not None: - nano.pop() - self.unsupported("Specifying nanoseconds is not supported in DATETIMEFROMPARTS.") - - if expression.args.get("milli") is None: - expression.set("milli", exp.Literal.number(0)) - - return rename_func("DATETIMEFROMPARTS")(self, expression) - - def setitem_sql(self, expression: exp.SetItem) -> str: - this = expression.this - if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter): - # T-SQL does not use '=' in SET command, except when the LHS is a variable. - return f"{self.sql(this.left)} {self.sql(this.right)}" - - return super().setitem_sql(expression) - - def boolean_sql(self, expression: exp.Boolean) -> str: - if type(expression.parent) in BIT_TYPES or isinstance( - expression.find_ancestor(exp.Values, exp.Select), exp.Values - ): - return "1" if expression.this else "0" - - return "(1 = 1)" if expression.this else "(1 = 0)" - - def is_sql(self, expression: exp.Is) -> str: - if isinstance(expression.expression, exp.Boolean): - return self.binary(expression, "=") - return self.binary(expression, "IS") - - def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: - sql = self.sql(expression, "this") - properties = expression.args.get("properties") - - if sql[:1] != "#" and any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) - ): - sql = f"[#{sql[1:]}" if sql.startswith("[") else f"#{sql}" - - return sql - - def create_sql(self, expression: exp.Create) -> str: - kind = expression.kind - exists = expression.args.pop("exists", None) - - like_property = expression.find(exp.LikeProperty) - if like_property: - ctas_expression = like_property.this - else: - ctas_expression = expression.expression - - if kind == "VIEW": - expression.this.set("catalog", None) - with_ = expression.args.get("with") - if ctas_expression and with_: - # We've already preprocessed the Create expression to bubble up any nested CTEs, - # but CREATE VIEW actually requires the WITH clause to come after it so we need - # to amend the AST by moving the CTEs to the CREATE VIEW statement's query. - ctas_expression.set("with", with_.pop()) - - sql = super().create_sql(expression) - - table = expression.find(exp.Table) - - # Convert CTAS statement to SELECT .. INTO .. 
- if kind == "TABLE" and ctas_expression: - if isinstance(ctas_expression, exp.UNWRAPPED_QUERIES): - ctas_expression = ctas_expression.subquery() - - properties = expression.args.get("properties") or exp.Properties() - is_temp = any(isinstance(p, exp.TemporaryProperty) for p in properties.expressions) - - select_into = exp.select("*").from_(exp.alias_(ctas_expression, "temp", table=True)) - select_into.set("into", exp.Into(this=table, temporary=is_temp)) - - if like_property: - select_into.limit(0, copy=False) - - sql = self.sql(select_into) - - if exists: - identifier = self.sql(exp.Literal.string(exp.table_name(table) if table else "")) - sql_with_ctes = self.prepend_ctes(expression, sql) - sql_literal = self.sql(exp.Literal.string(sql_with_ctes)) - if kind == "SCHEMA": - return f"""IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = {identifier}) EXEC({sql_literal})""" - elif kind == "TABLE": - assert table - where = exp.and_( - exp.column("table_name").eq(table.name), - exp.column("table_schema").eq(table.db) if table.db else None, - exp.column("table_catalog").eq(table.catalog) if table.catalog else None, - ) - return f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE {where}) EXEC({sql_literal})""" - elif kind == "INDEX": - index = self.sql(exp.Literal.string(expression.this.text("this"))) - return f"""IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id({identifier}) AND name = {index}) EXEC({sql_literal})""" - elif expression.args.get("replace"): - sql = sql.replace("CREATE OR REPLACE ", "CREATE OR ALTER ", 1) - - return self.prepend_ctes(expression, sql) - - @generator.unsupported_args("unlogged", "expressions") - def into_sql(self, expression: exp.Into) -> str: - if expression.args.get("temporary"): - # If the Into expression has a temporary property, push this down to the Identifier - table = expression.find(exp.Table) - if table and isinstance(table.this, exp.Identifier): - table.this.set("temporary", True) - - return f"{self.seg('INTO')} {self.sql(expression, 'this')}" - - def count_sql(self, expression: exp.Count) -> str: - func_name = "COUNT_BIG" if expression.args.get("big_int") else "COUNT" - return rename_func(func_name)(self, expression) - - def offset_sql(self, expression: exp.Offset) -> str: - return f"{super().offset_sql(expression)} ROWS" - - def version_sql(self, expression: exp.Version) -> str: - name = "SYSTEM_TIME" if expression.name == "TIMESTAMP" else expression.name - this = f"FOR {name}" - expr = expression.expression - kind = expression.text("kind") - if kind in ("FROM", "BETWEEN"): - args = expr.expressions - sep = "TO" if kind == "FROM" else "AND" - expr_sql = f"{self.sql(seq_get(args, 0))} {sep} {self.sql(seq_get(args, 1))}" - else: - expr_sql = self.sql(expr) - - expr_sql = f" {expr_sql}" if expr_sql else "" - return f"{this} {kind}{expr_sql}" - - def returnsproperty_sql(self, expression: exp.ReturnsProperty) -> str: - table = expression.args.get("table") - table = f"{table} " if table else "" - return f"RETURNS {table}{self.sql(expression, 'this')}" - - def returning_sql(self, expression: exp.Returning) -> str: - into = self.sql(expression, "into") - into = self.seg(f"INTO {into}") if into else "" - return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}" - - def transaction_sql(self, expression: exp.Transaction) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - mark = self.sql(expression, "mark") - mark = f" WITH MARK {mark}" if mark else "" - 
return f"BEGIN TRANSACTION{this}{mark}" - - def commit_sql(self, expression: exp.Commit) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - durability = expression.args.get("durability") - durability = ( - f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})" - if durability is not None - else "" - ) - return f"COMMIT TRANSACTION{this}{durability}" - - def rollback_sql(self, expression: exp.Rollback) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - return f"ROLLBACK TRANSACTION{this}" - - def identifier_sql(self, expression: exp.Identifier) -> str: - identifier = super().identifier_sql(expression) - - if expression.args.get("global"): - identifier = f"##{identifier}" - elif expression.args.get("temporary"): - identifier = f"#{identifier}" - - return identifier - - def constraint_sql(self, expression: exp.Constraint) -> str: - this = self.sql(expression, "this") - expressions = self.expressions(expression, flat=True, sep=" ") - return f"CONSTRAINT {this} {expressions}" - - def length_sql(self, expression: exp.Length) -> str: - return self._uncast_text(expression, "LEN") - - def right_sql(self, expression: exp.Right) -> str: - return self._uncast_text(expression, "RIGHT") - - def left_sql(self, expression: exp.Left) -> str: - return self._uncast_text(expression, "LEFT") - - def _uncast_text(self, expression: exp.Expression, name: str) -> str: - this = expression.this - if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.TEXT): - this_sql = self.sql(this, "this") - else: - this_sql = self.sql(this) - expression_sql = self.sql(expression, "expression") - return self.func(name, this_sql, expression_sql if expression_sql else None) - - def partition_sql(self, expression: exp.Partition) -> str: - return f"WITH (PARTITIONS({self.expressions(expression, flat=True)}))" - - def alter_sql(self, expression: exp.Alter) -> str: - action = seq_get(expression.args.get("actions") or [], 0) - if isinstance(action, exp.AlterRename): - return f"EXEC sp_rename '{self.sql(expression.this)}', '{action.this.name}'" - return super().alter_sql(expression) - - def drop_sql(self, expression: exp.Drop) -> str: - if expression.args["kind"] == "VIEW": - expression.this.set("catalog", None) - return super().drop_sql(expression) - - def options_modifier(self, expression: exp.Expression) -> str: - options = self.expressions(expression, key="options") - return f" OPTION{self.wrap(options)}" if options else "" - - def dpipe_sql(self, expression: exp.DPipe) -> str: - return self.sql( - reduce(lambda x, y: exp.Add(this=x, expression=y), expression.flatten()) - ) - - def isascii_sql(self, expression: exp.IsAscii) -> str: - return f"(PATINDEX(CONVERT(VARCHAR(MAX), 0x255b5e002d7f5d25) COLLATE Latin1_General_BIN, {self.sql(expression.this)}) = 0)" - - def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: - this = super().columndef_sql(expression, sep) - default = self.sql(expression, "default") - default = f" = {default}" if default else "" - output = self.sql(expression, "output") - output = f" {output}" if output else "" - return f"{this}{default}{output}" - - def coalesce_sql(self, expression: exp.Coalesce) -> str: - func_name = "ISNULL" if expression.args.get("is_null") else "COALESCE" - return rename_func(func_name)(self, expression) diff --git a/altimate_packages/sqlglot/diff.py b/altimate_packages/sqlglot/diff.py deleted file mode 100644 index f7cdebd82..000000000 --- a/altimate_packages/sqlglot/diff.py +++ 
/dev/null @@ -1,456 +0,0 @@ -""" -.. include:: ../posts/sql_diff.md - ----- -""" - -from __future__ import annotations - -import typing as t -from collections import defaultdict -from dataclasses import dataclass -from heapq import heappop, heappush -from itertools import chain - -from sqlglot import Dialect, expressions as exp -from sqlglot.helper import seq_get - -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - - -@dataclass(frozen=True) -class Insert: - """Indicates that a new node has been inserted""" - - expression: exp.Expression - - -@dataclass(frozen=True) -class Remove: - """Indicates that an existing node has been removed""" - - expression: exp.Expression - - -@dataclass(frozen=True) -class Move: - """Indicates that an existing node's position within the tree has changed""" - - source: exp.Expression - target: exp.Expression - - -@dataclass(frozen=True) -class Update: - """Indicates that an existing node has been updated""" - - source: exp.Expression - target: exp.Expression - - -@dataclass(frozen=True) -class Keep: - """Indicates that an existing node hasn't been changed""" - - source: exp.Expression - target: exp.Expression - - -if t.TYPE_CHECKING: - from sqlglot._typing import T - - Edit = t.Union[Insert, Remove, Move, Update, Keep] - - -def diff( - source: exp.Expression, - target: exp.Expression, - matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, - delta_only: bool = False, - **kwargs: t.Any, -) -> t.List[Edit]: - """ - Returns the list of changes between the source and the target expressions. - - Examples: - >>> diff(parse_one("a + b"), parse_one("a + c")) - [ - Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))), - Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))), - Keep( - source=(ADD this: ...), - target=(ADD this: ...) - ), - Keep( - source=(COLUMN this: (IDENTIFIER this: a, quoted: False)), - target=(COLUMN this: (IDENTIFIER this: a, quoted: False)) - ), - ] - - Args: - source: the source expression. - target: the target expression against which the diff should be calculated. - matchings: the list of pre-matched node pairs which is used to help the algorithm's - heuristics produce better results for subtrees that are known by a caller to be matching. - Note: expression references in this list must refer to the same node objects that are - referenced in the source / target trees. - delta_only: excludes all `Keep` nodes from the diff. - kwargs: additional arguments to pass to the ChangeDistiller instance. - - Returns: - the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the - target expression trees. This list represents a sequence of steps needed to transform the source - expression tree into the target one. - """ - matchings = matchings or [] - - def compute_node_mappings( - old_nodes: tuple[exp.Expression, ...], new_nodes: tuple[exp.Expression, ...] 
- ) -> t.Dict[int, exp.Expression]: - node_mapping = {} - for old_node, new_node in zip(reversed(old_nodes), reversed(new_nodes)): - new_node._hash = hash(new_node) - node_mapping[id(old_node)] = new_node - - return node_mapping - - # if the source and target have any shared objects, that means there's an issue with the ast - # the algorithm won't work because the parent / hierarchies will be inaccurate - source_nodes = tuple(source.walk()) - target_nodes = tuple(target.walk()) - source_ids = {id(n) for n in source_nodes} - target_ids = {id(n) for n in target_nodes} - - copy = ( - len(source_nodes) != len(source_ids) - or len(target_nodes) != len(target_ids) - or source_ids & target_ids - ) - - source_copy = source.copy() if copy else source - target_copy = target.copy() if copy else target - - try: - # We cache the hash of each new node here to speed up equality comparisons. If the input - # trees aren't copied, these hashes will be evicted before returning the edit script. - if copy and matchings: - source_mapping = compute_node_mappings(source_nodes, tuple(source_copy.walk())) - target_mapping = compute_node_mappings(target_nodes, tuple(target_copy.walk())) - matchings = [(source_mapping[id(s)], target_mapping[id(t)]) for s, t in matchings] - else: - for node in chain(reversed(source_nodes), reversed(target_nodes)): - node._hash = hash(node) - - edit_script = ChangeDistiller(**kwargs).diff( - source_copy, - target_copy, - matchings=matchings, - delta_only=delta_only, - ) - finally: - if not copy: - for node in chain(source_nodes, target_nodes): - node._hash = None - - return edit_script - - -# The expression types for which Update edits are allowed. -UPDATABLE_EXPRESSION_TYPES = ( - exp.Alias, - exp.Boolean, - exp.Column, - exp.DataType, - exp.Lambda, - exp.Literal, - exp.Table, - exp.Window, -) - -IGNORED_LEAF_EXPRESSION_TYPES = (exp.Identifier,) - - -class ChangeDistiller: - """ - The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in - their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by - Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf. 
- """ - - def __init__(self, f: float = 0.6, t: float = 0.6, dialect: DialectType = None) -> None: - self.f = f - self.t = t - self._sql_generator = Dialect.get_or_raise(dialect).generator() - - def diff( - self, - source: exp.Expression, - target: exp.Expression, - matchings: t.List[t.Tuple[exp.Expression, exp.Expression]] | None = None, - delta_only: bool = False, - ) -> t.List[Edit]: - matchings = matchings or [] - pre_matched_nodes = {id(s): id(t) for s, t in matchings} - - self._source = source - self._target = target - self._source_index = { - id(n): n for n in self._source.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) - } - self._target_index = { - id(n): n for n in self._target.bfs() if not isinstance(n, IGNORED_LEAF_EXPRESSION_TYPES) - } - self._unmatched_source_nodes = set(self._source_index) - set(pre_matched_nodes) - self._unmatched_target_nodes = set(self._target_index) - set(pre_matched_nodes.values()) - self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {} - - matching_set = self._compute_matching_set() | set(pre_matched_nodes.items()) - return self._generate_edit_script(dict(matching_set), delta_only) - - def _generate_edit_script(self, matchings: t.Dict[int, int], delta_only: bool) -> t.List[Edit]: - edit_script: t.List[Edit] = [] - for removed_node_id in self._unmatched_source_nodes: - edit_script.append(Remove(self._source_index[removed_node_id])) - for inserted_node_id in self._unmatched_target_nodes: - edit_script.append(Insert(self._target_index[inserted_node_id])) - for kept_source_node_id, kept_target_node_id in matchings.items(): - source_node = self._source_index[kept_source_node_id] - target_node = self._target_index[kept_target_node_id] - - identical_nodes = source_node == target_node - - if not isinstance(source_node, UPDATABLE_EXPRESSION_TYPES) or identical_nodes: - if identical_nodes: - source_parent = source_node.parent - target_parent = target_node.parent - - if ( - (source_parent and not target_parent) - or (not source_parent and target_parent) - or ( - source_parent - and target_parent - and matchings.get(id(source_parent)) != id(target_parent) - ) - ): - edit_script.append(Move(source=source_node, target=target_node)) - else: - edit_script.extend( - self._generate_move_edits(source_node, target_node, matchings) - ) - - source_non_expression_leaves = dict(_get_non_expression_leaves(source_node)) - target_non_expression_leaves = dict(_get_non_expression_leaves(target_node)) - - if source_non_expression_leaves != target_non_expression_leaves: - edit_script.append(Update(source_node, target_node)) - elif not delta_only: - edit_script.append(Keep(source_node, target_node)) - else: - edit_script.append(Update(source_node, target_node)) - - return edit_script - - def _generate_move_edits( - self, source: exp.Expression, target: exp.Expression, matchings: t.Dict[int, int] - ) -> t.List[Move]: - source_args = [id(e) for e in _expression_only_args(source)] - target_args = [id(e) for e in _expression_only_args(target)] - - args_lcs = set( - _lcs(source_args, target_args, lambda l, r: matchings.get(t.cast(int, l)) == r) - ) - - move_edits = [] - for a in source_args: - if a not in args_lcs and a not in self._unmatched_source_nodes: - move_edits.append( - Move(source=self._source_index[a], target=self._target_index[matchings[a]]) - ) - - return move_edits - - def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]: - leaves_matching_set = self._compute_leaf_matching_set() - matching_set = leaves_matching_set.copy() - - 
ordered_unmatched_source_nodes = { - id(n): None for n in self._source.bfs() if id(n) in self._unmatched_source_nodes - } - ordered_unmatched_target_nodes = { - id(n): None for n in self._target.bfs() if id(n) in self._unmatched_target_nodes - } - - for source_node_id in ordered_unmatched_source_nodes: - for target_node_id in ordered_unmatched_target_nodes: - source_node = self._source_index[source_node_id] - target_node = self._target_index[target_node_id] - if _is_same_type(source_node, target_node): - source_leaf_ids = {id(l) for l in _get_expression_leaves(source_node)} - target_leaf_ids = {id(l) for l in _get_expression_leaves(target_node)} - - max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids)) - if max_leaves_num: - common_leaves_num = sum( - 1 if s in source_leaf_ids and t in target_leaf_ids else 0 - for s, t in leaves_matching_set - ) - leaf_similarity_score = common_leaves_num / max_leaves_num - else: - leaf_similarity_score = 0.0 - - adjusted_t = ( - self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4 - ) - - if leaf_similarity_score >= 0.8 or ( - leaf_similarity_score >= adjusted_t - and self._dice_coefficient(source_node, target_node) >= self.f - ): - matching_set.add((source_node_id, target_node_id)) - self._unmatched_source_nodes.remove(source_node_id) - self._unmatched_target_nodes.remove(target_node_id) - ordered_unmatched_target_nodes.pop(target_node_id, None) - break - - return matching_set - - def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]: - candidate_matchings: t.List[t.Tuple[float, int, int, exp.Expression, exp.Expression]] = [] - source_expression_leaves = list(_get_expression_leaves(self._source)) - target_expression_leaves = list(_get_expression_leaves(self._target)) - for source_leaf in source_expression_leaves: - for target_leaf in target_expression_leaves: - if _is_same_type(source_leaf, target_leaf): - similarity_score = self._dice_coefficient(source_leaf, target_leaf) - if similarity_score >= self.f: - heappush( - candidate_matchings, - ( - -similarity_score, - -_parent_similarity_score(source_leaf, target_leaf), - len(candidate_matchings), - source_leaf, - target_leaf, - ), - ) - - # Pick best matchings based on the highest score - matching_set = set() - while candidate_matchings: - _, _, _, source_leaf, target_leaf = heappop(candidate_matchings) - if ( - id(source_leaf) in self._unmatched_source_nodes - and id(target_leaf) in self._unmatched_target_nodes - ): - matching_set.add((id(source_leaf), id(target_leaf))) - self._unmatched_source_nodes.remove(id(source_leaf)) - self._unmatched_target_nodes.remove(id(target_leaf)) - - return matching_set - - def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float: - source_histo = self._bigram_histo(source) - target_histo = self._bigram_histo(target) - - total_grams = sum(source_histo.values()) + sum(target_histo.values()) - if not total_grams: - return 1.0 if source == target else 0.0 - - overlap_len = 0 - overlapping_grams = set(source_histo) & set(target_histo) - for g in overlapping_grams: - overlap_len += min(source_histo[g], target_histo[g]) - - return 2 * overlap_len / total_grams - - def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]: - if id(expression) in self._bigram_histo_cache: - return self._bigram_histo_cache[id(expression)] - - expression_str = self._sql_generator.generate(expression) - count = max(0, len(expression_str) - 1) - bigram_histo: t.DefaultDict[str, int] = defaultdict(int) - for i in 
range(count): - bigram_histo[expression_str[i : i + 2]] += 1 - - self._bigram_histo_cache[id(expression)] = bigram_histo - return bigram_histo - - -def _get_expression_leaves(expression: exp.Expression) -> t.Iterator[exp.Expression]: - has_child_exprs = False - - for node in expression.iter_expressions(): - if not isinstance(node, IGNORED_LEAF_EXPRESSION_TYPES): - has_child_exprs = True - yield from _get_expression_leaves(node) - - if not has_child_exprs: - yield expression - - -def _get_non_expression_leaves(expression: exp.Expression) -> t.Iterator[t.Tuple[str, t.Any]]: - for arg, value in expression.args.items(): - if isinstance(value, exp.Expression) or ( - isinstance(value, list) and isinstance(seq_get(value, 0), exp.Expression) - ): - continue - - yield (arg, value) - - -def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool: - if type(source) is type(target): - if isinstance(source, exp.Join): - return source.args.get("side") == target.args.get("side") - - if isinstance(source, exp.Anonymous): - return source.this == target.this - - return True - - return False - - -def _parent_similarity_score( - source: t.Optional[exp.Expression], target: t.Optional[exp.Expression] -) -> int: - if source is None or target is None or type(source) is not type(target): - return 0 - - return 1 + _parent_similarity_score(source.parent, target.parent) - - -def _expression_only_args(expression: exp.Expression) -> t.Iterator[exp.Expression]: - yield from ( - arg - for arg in expression.iter_expressions() - if not isinstance(arg, IGNORED_LEAF_EXPRESSION_TYPES) - ) - - -def _lcs( - seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool] -) -> t.Sequence[t.Optional[T]]: - """Calculates the longest common subsequence""" - - len_a = len(seq_a) - len_b = len(seq_b) - lcs_result = [[None] * (len_b + 1) for i in range(len_a + 1)] - - for i in range(len_a + 1): - for j in range(len_b + 1): - if i == 0 or j == 0: - lcs_result[i][j] = [] # type: ignore - elif equal(seq_a[i - 1], seq_b[j - 1]): - lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore - else: - lcs_result[i][j] = ( - lcs_result[i - 1][j] - if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore - else lcs_result[i][j - 1] - ) - - return lcs_result[len_a][len_b] # type: ignore diff --git a/altimate_packages/sqlglot/errors.py b/altimate_packages/sqlglot/errors.py deleted file mode 100644 index 300c21574..000000000 --- a/altimate_packages/sqlglot/errors.py +++ /dev/null @@ -1,93 +0,0 @@ -from __future__ import annotations - -import typing as t -from enum import auto - -from sqlglot.helper import AutoName - - -class ErrorLevel(AutoName): - IGNORE = auto() - """Ignore all errors.""" - - WARN = auto() - """Log all errors.""" - - RAISE = auto() - """Collect all errors and raise a single exception.""" - - IMMEDIATE = auto() - """Immediately raise an exception on the first error found.""" - - -class SqlglotError(Exception): - pass - - -class UnsupportedError(SqlglotError): - pass - - -class ParseError(SqlglotError): - def __init__( - self, - message: str, - errors: t.Optional[t.List[t.Dict[str, t.Any]]] = None, - ): - super().__init__(message) - self.errors = errors or [] - - @classmethod - def new( - cls, - message: str, - description: t.Optional[str] = None, - line: t.Optional[int] = None, - col: t.Optional[int] = None, - start_context: t.Optional[str] = None, - highlight: t.Optional[str] = None, - end_context: t.Optional[str] = None, - into_expression: t.Optional[str] = None, - ) -> 
ParseError: - return cls( - message, - [ - { - "description": description, - "line": line, - "col": col, - "start_context": start_context, - "highlight": highlight, - "end_context": end_context, - "into_expression": into_expression, - } - ], - ) - - -class TokenError(SqlglotError): - pass - - -class OptimizeError(SqlglotError): - pass - - -class SchemaError(SqlglotError): - pass - - -class ExecuteError(SqlglotError): - pass - - -def concat_messages(errors: t.Sequence[t.Any], maximum: int) -> str: - msg = [str(e) for e in errors[:maximum]] - remaining = len(errors) - maximum - if remaining > 0: - msg.append(f"... and {remaining} more") - return "\n\n".join(msg) - - -def merge_errors(errors: t.Sequence[ParseError]) -> t.List[t.Dict[str, t.Any]]: - return [e_dict for error in errors for e_dict in error.errors] diff --git a/altimate_packages/sqlglot/executor/__init__.py b/altimate_packages/sqlglot/executor/__init__.py deleted file mode 100644 index 432fe3f00..000000000 --- a/altimate_packages/sqlglot/executor/__init__.py +++ /dev/null @@ -1,95 +0,0 @@ -""" -.. include:: ../../posts/python_sql_engine.md - ----- -""" - -from __future__ import annotations - -import logging -import time -import typing as t - -from sqlglot import exp -from sqlglot.errors import ExecuteError -from sqlglot.executor.python import PythonExecutor -from sqlglot.executor.table import Table, ensure_tables -from sqlglot.helper import dict_depth -from sqlglot.optimizer import optimize -from sqlglot.optimizer.annotate_types import annotate_types -from sqlglot.planner import Plan -from sqlglot.schema import ensure_schema, flatten_schema, nested_get, nested_set - -logger = logging.getLogger("sqlglot") - -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - from sqlglot.expressions import Expression - from sqlglot.schema import Schema - - -def execute( - sql: str | Expression, - schema: t.Optional[t.Dict | Schema] = None, - read: DialectType = None, - dialect: DialectType = None, - tables: t.Optional[t.Dict] = None, -) -> Table: - """ - Run a sql query against data. - - Args: - sql: a sql statement. - schema: database schema. - This can either be an instance of `Schema` or a mapping in one of the following forms: - 1. {table: {col: type}} - 2. {db: {table: {col: type}}} - 3. {catalog: {db: {table: {col: type}}}} - read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql"). - dialect: the SQL dialect (alias for read). - tables: additional tables to register. - - Returns: - Simple columnar data structure. 
- """ - read = read or dialect - tables_ = ensure_tables(tables, dialect=read) - - if not schema: - schema = {} - flattened_tables = flatten_schema(tables_.mapping, depth=dict_depth(tables_.mapping)) - - for keys in flattened_tables: - table = nested_get(tables_.mapping, *zip(keys, keys)) - assert table is not None - - for column in table.columns: - value = table[0][column] - column_type = ( - annotate_types(exp.convert(value), dialect=read).type or type(value).__name__ - ) - nested_set(schema, [*keys, column], column_type) - - schema = ensure_schema(schema, dialect=read) - - if tables_.supported_table_args and tables_.supported_table_args != schema.supported_table_args: - raise ExecuteError("Tables must support the same table args as schema") - - now = time.time() - expression = optimize( - sql, schema, leave_tables_isolated=True, infer_csv_schemas=True, dialect=read - ) - - logger.debug("Optimization finished: %f", time.time() - now) - logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) - - plan = Plan(expression) - - logger.debug("Logical Plan: %s", plan) - - now = time.time() - result = PythonExecutor(tables=tables_).execute(plan) - - logger.debug("Query finished: %f", time.time() - now) - - return result diff --git a/altimate_packages/sqlglot/executor/context.py b/altimate_packages/sqlglot/executor/context.py deleted file mode 100644 index a411c18f0..000000000 --- a/altimate_packages/sqlglot/executor/context.py +++ /dev/null @@ -1,101 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot.executor.env import ENV - -if t.TYPE_CHECKING: - from sqlglot.executor.table import Table, TableIter - - -class Context: - """ - Execution context for sql expressions. - - Context is used to hold relevant data tables which can then be queried on with eval. - - References to columns can either be scalar or vectors. When set_row is used, column references - evaluate to scalars while set_range evaluates to vectors. This allows convenient and efficient - evaluation of aggregation functions. - """ - - def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None: - """ - Args - tables: representing the scope of the current execution context. - env: dictionary of functions within the execution context. 
- """ - self.tables = tables - self._table: t.Optional[Table] = None - self.range_readers = {name: table.range_reader for name, table in self.tables.items()} - self.row_readers = {name: table.reader for name, table in tables.items()} - self.env = {**ENV, **(env or {}), "scope": self.row_readers} - - def eval(self, code): - return eval(code, self.env) - - def eval_tuple(self, codes): - return tuple(self.eval(code) for code in codes) - - @property - def table(self) -> Table: - if self._table is None: - self._table = list(self.tables.values())[0] - - for other in self.tables.values(): - if self._table.columns != other.columns: - raise Exception("Columns are different.") - if len(self._table.rows) != len(other.rows): - raise Exception("Rows are different.") - - return self._table - - def add_columns(self, *columns: str) -> None: - for table in self.tables.values(): - table.add_columns(*columns) - - @property - def columns(self) -> t.Tuple: - return self.table.columns - - def __iter__(self): - self.env["scope"] = self.row_readers - for i in range(len(self.table.rows)): - for table in self.tables.values(): - reader = table[i] - yield reader, self - - def table_iter(self, table: str) -> TableIter: - self.env["scope"] = self.row_readers - return iter(self.tables[table]) - - def filter(self, condition) -> None: - rows = [reader.row for reader, _ in self if self.eval(condition)] - - for table in self.tables.values(): - table.rows = rows - - def sort(self, key) -> None: - def sort_key(row: t.Tuple) -> t.Tuple: - self.set_row(row) - return tuple((t is None, t) for t in self.eval_tuple(key)) - - self.table.rows.sort(key=sort_key) - - def set_row(self, row: t.Tuple) -> None: - for table in self.tables.values(): - table.reader.row = row - self.env["scope"] = self.row_readers - - def set_index(self, index: int) -> None: - for table in self.tables.values(): - table[index] - self.env["scope"] = self.row_readers - - def set_range(self, start: int, end: int) -> None: - for name in self.tables: - self.range_readers[name].range = range(start, end) - self.env["scope"] = self.range_readers - - def __contains__(self, table: str) -> bool: - return table in self.tables diff --git a/altimate_packages/sqlglot/executor/env.py b/altimate_packages/sqlglot/executor/env.py deleted file mode 100644 index c6c8ee07e..000000000 --- a/altimate_packages/sqlglot/executor/env.py +++ /dev/null @@ -1,246 +0,0 @@ -import datetime -import inspect -import re -import statistics -from functools import wraps - -from sqlglot import exp -from sqlglot.generator import Generator -from sqlglot.helper import PYTHON_VERSION, is_int, seq_get - - -class reverse_key: - def __init__(self, obj): - self.obj = obj - - def __eq__(self, other): - return other.obj == self.obj - - def __lt__(self, other): - return other.obj < self.obj - - -def filter_nulls(func, empty_null=True): - @wraps(func) - def _func(values): - filtered = tuple(v for v in values if v is not None) - if not filtered and empty_null: - return None - return func(filtered) - - return _func - - -def null_if_any(*required): - """ - Decorator that makes a function return `None` if any of the `required` arguments are `None`. - - This also supports decoration with no arguments, e.g.: - - @null_if_any - def foo(a, b): ... - - In which case all arguments are required. 
- """ - f = None - if len(required) == 1 and callable(required[0]): - f = required[0] - required = () - - def decorator(func): - if required: - required_indices = [ - i for i, param in enumerate(inspect.signature(func).parameters) if param in required - ] - - def predicate(*args): - return any(args[i] is None for i in required_indices) - - else: - - def predicate(*args): - return any(a is None for a in args) - - @wraps(func) - def _func(*args): - if predicate(*args): - return None - return func(*args) - - return _func - - if f: - return decorator(f) - - return decorator - - -@null_if_any("this", "substr") -def str_position(this, substr, position=None): - position = position - 1 if position is not None else position - return this.find(substr, position) + 1 - - -@null_if_any("this") -def substring(this, start=None, length=None): - if start is None: - return this - elif start == 0: - return "" - elif start < 0: - start = len(this) + start - else: - start -= 1 - - end = None if length is None else start + length - - return this[start:end] - - -@null_if_any -def cast(this, to): - if to == exp.DataType.Type.DATE: - if isinstance(this, datetime.datetime): - return this.date() - if isinstance(this, datetime.date): - return this - if isinstance(this, str): - return datetime.date.fromisoformat(this) - if to == exp.DataType.Type.TIME: - if isinstance(this, datetime.datetime): - return this.time() - if isinstance(this, datetime.time): - return this - if isinstance(this, str): - return datetime.time.fromisoformat(this) - if to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP): - if isinstance(this, datetime.datetime): - return this - if isinstance(this, datetime.date): - return datetime.datetime(this.year, this.month, this.day) - if isinstance(this, str): - return datetime.datetime.fromisoformat(this) - if to == exp.DataType.Type.BOOLEAN: - return bool(this) - if to in exp.DataType.TEXT_TYPES: - return str(this) - if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}: - return float(this) - if to in exp.DataType.NUMERIC_TYPES: - return int(this) - raise NotImplementedError(f"Casting {this} to '{to}' not implemented.") - - -def ordered(this, desc, nulls_first): - if desc: - return reverse_key(this) - return this - - -@null_if_any -def interval(this, unit): - plural = unit + "S" - if plural in Generator.TIME_PART_SINGULARS: - unit = plural - return datetime.timedelta(**{unit.lower(): float(this)}) - - -@null_if_any("this", "expression") -def arraytostring(this, expression, null=None): - return expression.join(x for x in (x if x is not None else null for x in this) if x is not None) - - -@null_if_any("this", "expression") -def jsonextract(this, expression): - for path_segment in expression: - if isinstance(this, dict): - this = this.get(path_segment) - elif isinstance(this, list) and is_int(path_segment): - this = seq_get(this, int(path_segment)) - else: - raise NotImplementedError(f"Unable to extract value for {this} at {path_segment}.") - - if this is None: - break - - return this - - -ENV = { - "exp": exp, - # aggs - "ARRAYAGG": list, - "ARRAYUNIQUEAGG": filter_nulls(lambda acc: list(set(acc))), - "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore - "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False), - "MAX": filter_nulls(max), - "MIN": filter_nulls(min), - "SUM": filter_nulls(sum), - # scalar functions - "ABS": null_if_any(lambda this: abs(this)), - "ADD": null_if_any(lambda e, this: e + this), - "ARRAYANY": 
null_if_any(lambda arr, func: any(func(e) for e in arr)), - "ARRAYTOSTRING": arraytostring, - "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), - "BITWISEAND": null_if_any(lambda this, e: this & e), - "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), - "BITWISEOR": null_if_any(lambda this, e: this | e), - "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e), - "BITWISEXOR": null_if_any(lambda this, e: this ^ e), - "CAST": cast, - "COALESCE": lambda *args: next((a for a in args if a is not None), None), - "CONCAT": null_if_any(lambda *args: "".join(args)), - "SAFECONCAT": null_if_any(lambda *args: "".join(str(arg) for arg in args)), - "CONCATWS": null_if_any(lambda this, *args: this.join(args)), - "DATEDIFF": null_if_any(lambda this, expression, *_: (this - expression).days), - "DATESTRTODATE": null_if_any(lambda arg: datetime.date.fromisoformat(arg)), - "DIV": null_if_any(lambda e, this: e / this), - "DOT": null_if_any(lambda e, this: e[this]), - "EQ": null_if_any(lambda this, e: this == e), - "EXTRACT": null_if_any(lambda this, e: getattr(e, this)), - "GT": null_if_any(lambda this, e: this > e), - "GTE": null_if_any(lambda this, e: this >= e), - "IF": lambda predicate, true, false: true if predicate else false, - "INTDIV": null_if_any(lambda e, this: e // this), - "INTERVAL": interval, - "JSONEXTRACT": jsonextract, - "LEFT": null_if_any(lambda this, e: this[:e]), - "LIKE": null_if_any( - lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this)) - ), - "LOWER": null_if_any(lambda arg: arg.lower()), - "LT": null_if_any(lambda this, e: this < e), - "LTE": null_if_any(lambda this, e: this <= e), - "MAP": null_if_any(lambda *args: dict(zip(*args))), # type: ignore - "MOD": null_if_any(lambda e, this: e % this), - "MUL": null_if_any(lambda e, this: e * this), - "NEQ": null_if_any(lambda this, e: this != e), - "ORD": null_if_any(ord), - "ORDERED": ordered, - "POW": pow, - "RIGHT": null_if_any(lambda this, e: this[-e:]), - "ROUND": null_if_any(lambda this, decimals=None, truncate=None: round(this, ndigits=decimals)), - "STRPOSITION": str_position, - "SUB": null_if_any(lambda e, this: e - this), - "SUBSTRING": substring, - "TIMESTRTOTIME": null_if_any(lambda arg: datetime.datetime.fromisoformat(arg)), - "UPPER": null_if_any(lambda arg: arg.upper()), - "YEAR": null_if_any(lambda arg: arg.year), - "MONTH": null_if_any(lambda arg: arg.month), - "DAY": null_if_any(lambda arg: arg.day), - "CURRENTDATETIME": datetime.datetime.now, - "CURRENTTIMESTAMP": datetime.datetime.now, - "CURRENTTIME": datetime.datetime.now, - "CURRENTDATE": datetime.date.today, - "STRFTIME": null_if_any(lambda fmt, arg: datetime.datetime.fromisoformat(arg).strftime(fmt)), - "STRTOTIME": null_if_any(lambda arg, format: datetime.datetime.strptime(arg, format)), - "TRIM": null_if_any(lambda this, e=None: this.strip(e)), - "STRUCT": lambda *args: { - args[x]: args[x + 1] - for x in range(0, len(args), 2) - if (args[x + 1] is not None and args[x] is not None) - }, - "UNIXTOTIME": null_if_any( - lambda arg: datetime.datetime.fromtimestamp(arg, datetime.timezone.utc) - ), -} diff --git a/altimate_packages/sqlglot/executor/python.py b/altimate_packages/sqlglot/executor/python.py deleted file mode 100644 index ce7a2a872..000000000 --- a/altimate_packages/sqlglot/executor/python.py +++ /dev/null @@ -1,460 +0,0 @@ -import ast -import collections -import itertools -import math - -from sqlglot import exp, generator, planner, tokens -from sqlglot.dialects.dialect import Dialect, 
inline_array_sql -from sqlglot.errors import ExecuteError -from sqlglot.executor.context import Context -from sqlglot.executor.env import ENV -from sqlglot.executor.table import RowReader, Table -from sqlglot.helper import csv_reader, subclasses - - -class PythonExecutor: - def __init__(self, env=None, tables=None): - self.generator = Python().generator(identify=True, comments=False) - self.env = {**ENV, **(env or {})} - self.tables = tables or {} - - def execute(self, plan): - finished = set() - queue = set(plan.leaves) - contexts = {} - - while queue: - node = queue.pop() - try: - context = self.context( - { - name: table - for dep in node.dependencies - for name, table in contexts[dep].tables.items() - } - ) - - if isinstance(node, planner.Scan): - contexts[node] = self.scan(node, context) - elif isinstance(node, planner.Aggregate): - contexts[node] = self.aggregate(node, context) - elif isinstance(node, planner.Join): - contexts[node] = self.join(node, context) - elif isinstance(node, planner.Sort): - contexts[node] = self.sort(node, context) - elif isinstance(node, planner.SetOperation): - contexts[node] = self.set_operation(node, context) - else: - raise NotImplementedError - - finished.add(node) - - for dep in node.dependents: - if all(d in contexts for d in dep.dependencies): - queue.add(dep) - - for dep in node.dependencies: - if all(d in finished for d in dep.dependents): - contexts.pop(dep) - except Exception as e: - raise ExecuteError(f"Step '{node.id}' failed: {e}") from e - - root = plan.root - return contexts[root].tables[root.name] - - def generate(self, expression): - """Convert a SQL expression into literal Python code and compile it into bytecode.""" - if not expression: - return None - - sql = self.generator.generate(expression) - return compile(sql, sql, "eval", optimize=2) - - def generate_tuple(self, expressions): - """Convert an array of SQL expressions into tuple of Python byte code.""" - if not expressions: - return tuple() - return tuple(self.generate(expression) for expression in expressions) - - def context(self, tables): - return Context(tables, env=self.env) - - def table(self, expressions): - return Table( - expression.alias_or_name if isinstance(expression, exp.Expression) else expression - for expression in expressions - ) - - def scan(self, step, context): - source = step.source - - if source and isinstance(source, exp.Expression): - source = source.name or source.alias - - if source is None: - context, table_iter = self.static() - elif source in context: - if not step.projections and not step.condition: - return self.context({step.name: context.tables[source]}) - table_iter = context.table_iter(source) - elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV): - table_iter = self.scan_csv(step) - context = next(table_iter) - else: - context, table_iter = self.scan_table(step) - - return self.context({step.name: self._project_and_filter(context, step, table_iter)}) - - def _project_and_filter(self, context, step, table_iter): - sink = self.table(step.projections if step.projections else context.columns) - condition = self.generate(step.condition) - projections = self.generate_tuple(step.projections) - - for reader in table_iter: - if len(sink) >= step.limit: - break - - if condition and not context.eval(condition): - continue - - if projections: - sink.append(context.eval_tuple(projections)) - else: - sink.append(reader.row) - - return sink - - def static(self): - return self.context({}), [RowReader(())] - - def 
scan_table(self, step): - table = self.tables.find(step.source) - context = self.context({step.source.alias_or_name: table}) - return context, iter(table) - - def scan_csv(self, step): - alias = step.source.alias - source = step.source.this - - with csv_reader(source) as reader: - columns = next(reader) - table = Table(columns) - context = self.context({alias: table}) - yield context - types = [] - for row in reader: - if not types: - for v in row: - try: - types.append(type(ast.literal_eval(v))) - except (ValueError, SyntaxError): - types.append(str) - - # We can't cast empty values ('') to non-string types, so we convert them to None instead - context.set_row( - tuple(None if (t is not str and v == "") else t(v) for t, v in zip(types, row)) - ) - yield context.table.reader - - def join(self, step, context): - source = step.source_name - - source_table = context.tables[source] - source_context = self.context({source: source_table}) - column_ranges = {source: range(0, len(source_table.columns))} - - for name, join in step.joins.items(): - table = context.tables[name] - start = max(r.stop for r in column_ranges.values()) - column_ranges[name] = range(start, len(table.columns) + start) - join_context = self.context({name: table}) - - if join.get("source_key"): - table = self.hash_join(join, source_context, join_context) - else: - table = self.nested_loop_join(join, source_context, join_context) - - source_context = self.context( - { - name: Table(table.columns, table.rows, column_range) - for name, column_range in column_ranges.items() - } - ) - condition = self.generate(join["condition"]) - if condition: - source_context.filter(condition) - - if not step.condition and not step.projections: - return source_context - - sink = self._project_and_filter( - source_context, - step, - (reader for reader, _ in iter(source_context)), - ) - - if step.projections: - return self.context({step.name: sink}) - else: - return self.context( - { - name: Table(table.columns, sink.rows, table.column_range) - for name, table in source_context.tables.items() - } - ) - - def nested_loop_join(self, _join, source_context, join_context): - table = Table(source_context.columns + join_context.columns) - - for reader_a, _ in source_context: - for reader_b, _ in join_context: - table.append(reader_a.row + reader_b.row) - - return table - - def hash_join(self, join, source_context, join_context): - source_key = self.generate_tuple(join["source_key"]) - join_key = self.generate_tuple(join["join_key"]) - left = join.get("side") == "LEFT" - right = join.get("side") == "RIGHT" - - results = collections.defaultdict(lambda: ([], [])) - - for reader, ctx in source_context: - results[ctx.eval_tuple(source_key)][0].append(reader.row) - for reader, ctx in join_context: - results[ctx.eval_tuple(join_key)][1].append(reader.row) - - table = Table(source_context.columns + join_context.columns) - nulls = [(None,) * len(join_context.columns if left else source_context.columns)] - - for a_group, b_group in results.values(): - if left: - b_group = b_group or nulls - elif right: - a_group = a_group or nulls - - for a_row, b_row in itertools.product(a_group, b_group): - table.append(a_row + b_row) - - return table - - def aggregate(self, step, context): - group_by = self.generate_tuple(step.group.values()) - aggregations = self.generate_tuple(step.aggregations) - operands = self.generate_tuple(step.operands) - - if operands: - operand_table = Table(self.table(step.operands).columns) - - for reader, ctx in context: - 
operand_table.append(ctx.eval_tuple(operands)) - - for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)): - context.table.rows[i] = a + b - - width = len(context.columns) - context.add_columns(*operand_table.columns) - - operand_table = Table( - context.columns, - context.table.rows, - range(width, width + len(operand_table.columns)), - ) - - context = self.context( - { - None: operand_table, - **context.tables, - } - ) - - context.sort(group_by) - - group = None - start = 0 - end = 1 - length = len(context.table) - table = self.table(list(step.group) + step.aggregations) - - def add_row(): - table.append(group + context.eval_tuple(aggregations)) - - if length: - for i in range(length): - context.set_index(i) - key = context.eval_tuple(group_by) - group = key if group is None else group - end += 1 - if key != group: - context.set_range(start, end - 2) - add_row() - group = key - start = end - 2 - if len(table.rows) >= step.limit: - break - if i == length - 1: - context.set_range(start, end - 1) - add_row() - elif step.limit > 0 and not group_by: - context.set_range(0, 0) - table.append(context.eval_tuple(aggregations)) - - context = self.context({step.name: table, **{name: table for name in context.tables}}) - - if step.projections or step.condition: - return self.scan(step, context) - return context - - def sort(self, step, context): - projections = self.generate_tuple(step.projections) - projection_columns = [p.alias_or_name for p in step.projections] - all_columns = list(context.columns) + projection_columns - sink = self.table(all_columns) - for reader, ctx in context: - sink.append(reader.row + ctx.eval_tuple(projections)) - - sort_ctx = self.context( - { - None: sink, - **{table: sink for table in context.tables}, - } - ) - sort_ctx.sort(self.generate_tuple(step.key)) - - if not math.isinf(step.limit): - sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit] - - output = Table( - projection_columns, - rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows], - ) - return self.context({step.name: output}) - - def set_operation(self, step, context): - left = context.tables[step.left] - right = context.tables[step.right] - - sink = self.table(left.columns) - - if issubclass(step.op, exp.Intersect): - sink.rows = list(set(left.rows).intersection(set(right.rows))) - elif issubclass(step.op, exp.Except): - sink.rows = list(set(left.rows).difference(set(right.rows))) - elif issubclass(step.op, exp.Union) and step.distinct: - sink.rows = list(set(left.rows).union(set(right.rows))) - else: - sink.rows = left.rows + right.rows - - if not math.isinf(step.limit): - sink.rows = sink.rows[0 : step.limit] - - return self.context({step.name: sink}) - - -def _ordered_py(self, expression): - this = self.sql(expression, "this") - desc = "True" if expression.args.get("desc") else "False" - nulls_first = "True" if expression.args.get("nulls_first") else "False" - return f"ORDERED({this}, {desc}, {nulls_first})" - - -def _rename(self, e): - try: - values = list(e.args.values()) - - if len(values) == 1: - values = values[0] - if not isinstance(values, list): - return self.func(e.key, values) - return self.func(e.key, *values) - - if isinstance(e, exp.Func) and e.is_var_len_args: - args = itertools.chain.from_iterable(x if isinstance(x, list) else [x] for x in values) - return self.func(e.key, *args) - - return self.func(e.key, *values) - except Exception as ex: - raise Exception(f"Could not rename {repr(e)}") from ex - - -def _case_sql(self, expression): - this = 
self.sql(expression, "this") - chain = self.sql(expression, "default") or "None" - - for e in reversed(expression.args["ifs"]): - true = self.sql(e, "true") - condition = self.sql(e, "this") - condition = f"{this} = ({condition})" if this else condition - chain = f"{true} if {condition} else ({chain})" - - return chain - - -def _lambda_sql(self, e: exp.Lambda) -> str: - names = {e.name.lower() for e in e.expressions} - - e = e.transform( - lambda n: ( - exp.var(n.name) if isinstance(n, exp.Identifier) and n.name.lower() in names else n - ) - ).assert_is(exp.Lambda) - - return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}" - - -def _div_sql(self: generator.Generator, e: exp.Div) -> str: - denominator = self.sql(e, "expression") - - if e.args.get("safe"): - denominator += " or None" - - sql = f"DIV({self.sql(e, 'this')}, {denominator})" - - if e.args.get("typed"): - sql = f"int({sql})" - - return sql - - -class Python(Dialect): - class Tokenizer(tokens.Tokenizer): - STRING_ESCAPES = ["\\"] - - class Generator(generator.Generator): - TRANSFORMS = { - **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)}, - **{klass: _rename for klass in exp.ALL_FUNCTIONS}, - exp.Case: _case_sql, - exp.Alias: lambda self, e: self.sql(e.this), - exp.Array: inline_array_sql, - exp.And: lambda self, e: self.binary(e, "and"), - exp.Between: _rename, - exp.Boolean: lambda self, e: "True" if e.this else "False", - exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})", - exp.Column: lambda self, - e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]", - exp.Concat: lambda self, e: self.func( - "SAFECONCAT" if e.args.get("safe") else "CONCAT", *e.expressions - ), - exp.Distinct: lambda self, e: f"set({self.sql(e, 'this')})", - exp.Div: _div_sql, - exp.Extract: lambda self, - e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", - exp.In: lambda self, - e: f"{self.sql(e, 'this')} in {{{self.expressions(e, flat=True)}}}", - exp.Interval: lambda self, e: f"INTERVAL({self.sql(e.this)}, '{self.sql(e.unit)}')", - exp.Is: lambda self, e: ( - self.binary(e, "==") if isinstance(e.this, exp.Literal) else self.binary(e, "is") - ), - exp.JSONExtract: lambda self, e: self.func(e.key, e.this, e.expression, *e.expressions), - exp.JSONPath: lambda self, e: f"[{','.join(self.sql(p) for p in e.expressions[1:])}]", - exp.JSONPathKey: lambda self, e: f"'{self.sql(e.this)}'", - exp.JSONPathSubscript: lambda self, e: f"'{e.this}'", - exp.Lambda: _lambda_sql, - exp.Not: lambda self, e: f"not {self.sql(e.this)}", - exp.Null: lambda *_: "None", - exp.Or: lambda self, e: self.binary(e, "or"), - exp.Ordered: _ordered_py, - exp.Star: lambda *_: "1", - } diff --git a/altimate_packages/sqlglot/executor/table.py b/altimate_packages/sqlglot/executor/table.py deleted file mode 100644 index 1613ea30d..000000000 --- a/altimate_packages/sqlglot/executor/table.py +++ /dev/null @@ -1,155 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot.dialects.dialect import DialectType -from sqlglot.helper import dict_depth -from sqlglot.schema import AbstractMappingSchema, normalize_name - - -class Table: - def __init__( - self, - columns: t.Iterable, - rows: t.Optional[t.List] = None, - column_range: t.Optional[range] = None, - ) -> None: - self.columns = tuple(columns) - self.column_range = column_range - self.reader = RowReader(self.columns, self.column_range) - self.rows = rows or [] - if rows: - assert len(rows[0]) == len(self.columns) - 
self.range_reader = RangeReader(self) - - def add_columns(self, *columns: str) -> None: - self.columns += columns - if self.column_range: - self.column_range = range( - self.column_range.start, self.column_range.stop + len(columns) - ) - self.reader = RowReader(self.columns, self.column_range) - - def append(self, row: t.List) -> None: - assert len(row) == len(self.columns) - self.rows.append(row) - - def pop(self) -> None: - self.rows.pop() - - def to_pylist(self) -> t.List: - return [dict(zip(self.columns, row)) for row in self.rows] - - @property - def width(self) -> int: - return len(self.columns) - - def __len__(self) -> int: - return len(self.rows) - - def __iter__(self) -> TableIter: - return TableIter(self) - - def __getitem__(self, index: int) -> RowReader: - self.reader.row = self.rows[index] - return self.reader - - def __repr__(self) -> str: - columns = tuple( - column - for i, column in enumerate(self.columns) - if not self.column_range or i in self.column_range - ) - widths = {column: len(column) for column in columns} - lines = [" ".join(column for column in columns)] - - for i, row in enumerate(self): - if i > 10: - break - - lines.append( - " ".join( - str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns - ) - ) - return "\n".join(lines) - - -class TableIter: - def __init__(self, table: Table) -> None: - self.table = table - self.index = -1 - - def __iter__(self) -> TableIter: - return self - - def __next__(self) -> RowReader: - self.index += 1 - if self.index < len(self.table): - return self.table[self.index] - raise StopIteration - - -class RangeReader: - def __init__(self, table: Table) -> None: - self.table = table - self.range = range(0) - - def __len__(self) -> int: - return len(self.range) - - def __getitem__(self, column: str): - return (self.table[i][column] for i in self.range) - - -class RowReader: - def __init__(self, columns, column_range=None): - self.columns = { - column: i for i, column in enumerate(columns) if not column_range or i in column_range - } - self.row = None - - def __getitem__(self, column): - return self.row[self.columns[column]] - - -class Tables(AbstractMappingSchema): - pass - - -def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables: - return Tables(_ensure_tables(d, dialect=dialect)) - - -def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict: - if not d: - return {} - - depth = dict_depth(d) - if depth > 1: - return { - normalize_name(k, dialect=dialect, is_table=True).name: _ensure_tables( - v, dialect=dialect - ) - for k, v in d.items() - } - - result = {} - for table_name, table in d.items(): - table_name = normalize_name(table_name, dialect=dialect).name - - if isinstance(table, Table): - result[table_name] = table - else: - table = [ - { - normalize_name(column_name, dialect=dialect).name: value - for column_name, value in row.items() - } - for row in table - ] - column_names = tuple(column_name for column_name in table[0]) if table else () - rows = [tuple(row[name] for name in column_names) for row in table] - result[table_name] = Table(columns=column_names, rows=rows) - - return result diff --git a/altimate_packages/sqlglot/expressions.py b/altimate_packages/sqlglot/expressions.py deleted file mode 100644 index 35801be7c..000000000 --- a/altimate_packages/sqlglot/expressions.py +++ /dev/null @@ -1,8870 +0,0 @@ -""" -## Expressions - -Every AST node in SQLGlot is represented by a subclass of `Expression`. 
- -This module contains the implementation of all supported `Expression` types. Additionally, -it exposes a number of helper functions, which are mainly used to programmatically build -SQL expressions, such as `sqlglot.expressions.select`. - ----- -""" - -from __future__ import annotations - -import datetime -import math -import numbers -import re -import textwrap -import typing as t -from collections import deque -from copy import deepcopy -from decimal import Decimal -from enum import auto -from functools import reduce - -from sqlglot.errors import ErrorLevel, ParseError -from sqlglot.helper import ( - AutoName, - camel_to_snake_case, - ensure_collection, - ensure_list, - seq_get, - split_num_words, - subclasses, - to_bool, -) -from sqlglot.tokens import Token, TokenError - -if t.TYPE_CHECKING: - from typing_extensions import Self - - from sqlglot._typing import E, Lit - from sqlglot.dialects.dialect import DialectType - - Q = t.TypeVar("Q", bound="Query") - S = t.TypeVar("S", bound="SetOperation") - - -class _Expression(type): - def __new__(cls, clsname, bases, attrs): - klass = super().__new__(cls, clsname, bases, attrs) - - # When an Expression class is created, its key is automatically set - # to be the lowercase version of the class' name. - klass.key = clsname.lower() - - # This is so that docstrings are not inherited in pdoc - klass.__doc__ = klass.__doc__ or "" - - return klass - - -SQLGLOT_META = "sqlglot.meta" -SQLGLOT_ANONYMOUS = "sqlglot.anonymous" -TABLE_PARTS = ("this", "db", "catalog") -COLUMN_PARTS = ("this", "table", "db", "catalog") -POSITION_META_KEYS = ("line", "col", "start", "end") - - -class Expression(metaclass=_Expression): - """ - The base class for all expressions in a syntax tree. Each Expression encapsulates any necessary - context, such as its child expressions, their names (arg keys), and whether a given child expression - is optional or not. - - Attributes: - key: a unique key for each class in the Expression hierarchy. This is useful for hashing - and representing expressions as strings. - arg_types: determines the arguments (child nodes) supported by an expression. It maps - arg keys to booleans that indicate whether the corresponding args are optional. - parent: a reference to the parent expression (or None, in case of root expressions). - arg_key: the arg key an expression is associated with, i.e. the name its parent expression - uses to refer to it. - index: the index of an expression if it is inside of a list argument in its parent. - comments: a list of comments that are associated with a given expression. This is used in - order to preserve comments when transpiling SQL code. - type: the `sqlglot.expressions.DataType` type of an expression. This is inferred by the - optimizer, in order to enable some transformations that require type information. - meta: a dictionary that can be used to store useful metadata for a given expression. - - Example: - >>> class Foo(Expression): - ... arg_types = {"this": True, "expression": False} - - The above definition informs us that Foo is an Expression that requires an argument called - "this" and may also optionally receive an argument called "expression". - - Args: - args: a mapping used for retrieving the arguments of an expression, given their arg keys. 
- """ - - key = "expression" - arg_types = {"this": True} - __slots__ = ("args", "parent", "arg_key", "index", "comments", "_type", "_meta", "_hash") - - def __init__(self, **args: t.Any): - self.args: t.Dict[str, t.Any] = args - self.parent: t.Optional[Expression] = None - self.arg_key: t.Optional[str] = None - self.index: t.Optional[int] = None - self.comments: t.Optional[t.List[str]] = None - self._type: t.Optional[DataType] = None - self._meta: t.Optional[t.Dict[str, t.Any]] = None - self._hash: t.Optional[int] = None - - for arg_key, value in self.args.items(): - self._set_parent(arg_key, value) - - def __eq__(self, other) -> bool: - return type(self) is type(other) and hash(self) == hash(other) - - @property - def hashable_args(self) -> t.Any: - return frozenset( - (k, tuple(_norm_arg(a) for a in v) if type(v) is list else _norm_arg(v)) - for k, v in self.args.items() - if not (v is None or v is False or (type(v) is list and not v)) - ) - - def __hash__(self) -> int: - if self._hash is not None: - return self._hash - - return hash((self.__class__, self.hashable_args)) - - @property - def this(self) -> t.Any: - """ - Retrieves the argument with key "this". - """ - return self.args.get("this") - - @property - def expression(self) -> t.Any: - """ - Retrieves the argument with key "expression". - """ - return self.args.get("expression") - - @property - def expressions(self) -> t.List[t.Any]: - """ - Retrieves the argument with key "expressions". - """ - return self.args.get("expressions") or [] - - def text(self, key) -> str: - """ - Returns a textual representation of the argument corresponding to "key". This can only be used - for args that are strings or leaf Expression instances, such as identifiers and literals. - """ - field = self.args.get(key) - if isinstance(field, str): - return field - if isinstance(field, (Identifier, Literal, Var)): - return field.this - if isinstance(field, (Star, Null)): - return field.name - return "" - - @property - def is_string(self) -> bool: - """ - Checks whether a Literal expression is a string. - """ - return isinstance(self, Literal) and self.args["is_string"] - - @property - def is_number(self) -> bool: - """ - Checks whether a Literal expression is a number. - """ - return (isinstance(self, Literal) and not self.args["is_string"]) or ( - isinstance(self, Neg) and self.this.is_number - ) - - def to_py(self) -> t.Any: - """ - Returns a Python object equivalent of the SQL node. - """ - raise ValueError(f"{self} cannot be converted to a Python object.") - - @property - def is_int(self) -> bool: - """ - Checks whether an expression is an integer. - """ - return self.is_number and isinstance(self.to_py(), int) - - @property - def is_star(self) -> bool: - """Checks whether an expression is a star.""" - return isinstance(self, Star) or (isinstance(self, Column) and isinstance(self.this, Star)) - - @property - def alias(self) -> str: - """ - Returns the alias of the expression, or an empty string if it's not aliased. 
- """ - if isinstance(self.args.get("alias"), TableAlias): - return self.args["alias"].name - return self.text("alias") - - @property - def alias_column_names(self) -> t.List[str]: - table_alias = self.args.get("alias") - if not table_alias: - return [] - return [c.name for c in table_alias.args.get("columns") or []] - - @property - def name(self) -> str: - return self.text("this") - - @property - def alias_or_name(self) -> str: - return self.alias or self.name - - @property - def output_name(self) -> str: - """ - Name of the output column if this expression is a selection. - - If the Expression has no output name, an empty string is returned. - - Example: - >>> from sqlglot import parse_one - >>> parse_one("SELECT a").expressions[0].output_name - 'a' - >>> parse_one("SELECT b AS c").expressions[0].output_name - 'c' - >>> parse_one("SELECT 1 + 2").expressions[0].output_name - '' - """ - return "" - - @property - def type(self) -> t.Optional[DataType]: - return self._type - - @type.setter - def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None: - if dtype and not isinstance(dtype, DataType): - dtype = DataType.build(dtype) - self._type = dtype # type: ignore - - def is_type(self, *dtypes) -> bool: - return self.type is not None and self.type.is_type(*dtypes) - - def is_leaf(self) -> bool: - return not any(isinstance(v, (Expression, list)) for v in self.args.values()) - - @property - def meta(self) -> t.Dict[str, t.Any]: - if self._meta is None: - self._meta = {} - return self._meta - - def __deepcopy__(self, memo): - root = self.__class__() - stack = [(self, root)] - - while stack: - node, copy = stack.pop() - - if node.comments is not None: - copy.comments = deepcopy(node.comments) - if node._type is not None: - copy._type = deepcopy(node._type) - if node._meta is not None: - copy._meta = deepcopy(node._meta) - if node._hash is not None: - copy._hash = node._hash - - for k, vs in node.args.items(): - if hasattr(vs, "parent"): - stack.append((vs, vs.__class__())) - copy.set(k, stack[-1][-1]) - elif type(vs) is list: - copy.args[k] = [] - - for v in vs: - if hasattr(v, "parent"): - stack.append((v, v.__class__())) - copy.append(k, stack[-1][-1]) - else: - copy.append(k, v) - else: - copy.args[k] = vs - - return root - - def copy(self) -> Self: - """ - Returns a deep copy of the expression. - """ - return deepcopy(self) - - def add_comments(self, comments: t.Optional[t.List[str]] = None, prepend: bool = False) -> None: - if self.comments is None: - self.comments = [] - - if comments: - for comment in comments: - _, *meta = comment.split(SQLGLOT_META) - if meta: - for kv in "".join(meta).split(","): - k, *v = kv.split("=") - value = v[0].strip() if v else True - self.meta[k.strip()] = to_bool(value) - - if not prepend: - self.comments.append(comment) - - if prepend: - self.comments = comments + self.comments - - def pop_comments(self) -> t.List[str]: - comments = self.comments or [] - self.comments = None - return comments - - def append(self, arg_key: str, value: t.Any) -> None: - """ - Appends value to arg_key if it's a list or sets it as a new list. 
- - Args: - arg_key (str): name of the list expression arg - value (Any): value to append to the list - """ - if type(self.args.get(arg_key)) is not list: - self.args[arg_key] = [] - self._set_parent(arg_key, value) - values = self.args[arg_key] - if hasattr(value, "parent"): - value.index = len(values) - values.append(value) - - def set( - self, - arg_key: str, - value: t.Any, - index: t.Optional[int] = None, - overwrite: bool = True, - ) -> None: - """ - Sets arg_key to value. - - Args: - arg_key: name of the expression arg. - value: value to set the arg to. - index: if the arg is a list, this specifies what position to add the value in it. - overwrite: assuming an index is given, this determines whether to overwrite the - list entry instead of only inserting a new value (i.e., like list.insert). - """ - if index is not None: - expressions = self.args.get(arg_key) or [] - - if seq_get(expressions, index) is None: - return - if value is None: - expressions.pop(index) - for v in expressions[index:]: - v.index = v.index - 1 - return - - if isinstance(value, list): - expressions.pop(index) - expressions[index:index] = value - elif overwrite: - expressions[index] = value - else: - expressions.insert(index, value) - - value = expressions - elif value is None: - self.args.pop(arg_key, None) - return - - self.args[arg_key] = value - self._set_parent(arg_key, value, index) - - def _set_parent(self, arg_key: str, value: t.Any, index: t.Optional[int] = None) -> None: - if hasattr(value, "parent"): - value.parent = self - value.arg_key = arg_key - value.index = index - elif type(value) is list: - for index, v in enumerate(value): - if hasattr(v, "parent"): - v.parent = self - v.arg_key = arg_key - v.index = index - - @property - def depth(self) -> int: - """ - Returns the depth of this tree. - """ - if self.parent: - return self.parent.depth + 1 - return 0 - - def iter_expressions(self, reverse: bool = False) -> t.Iterator[Expression]: - """Yields the key and expression for all arguments, exploding list args.""" - for vs in reversed(self.args.values()) if reverse else self.args.values(): # type: ignore - if type(vs) is list: - for v in reversed(vs) if reverse else vs: # type: ignore - if hasattr(v, "parent"): - yield v - else: - if hasattr(vs, "parent"): - yield vs - - def find(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Optional[E]: - """ - Returns the first node in this tree which matches at least one of - the specified types. - - Args: - expression_types: the expression type(s) to match. - bfs: whether to search the AST using the BFS algorithm (DFS is used if false). - - Returns: - The node which matches the criteria or None if no such node was found. - """ - return next(self.find_all(*expression_types, bfs=bfs), None) - - def find_all(self, *expression_types: t.Type[E], bfs: bool = True) -> t.Iterator[E]: - """ - Returns a generator object which visits all nodes in this tree and only - yields those that match at least one of the specified expression types. - - Args: - expression_types: the expression type(s) to match. - bfs: whether to search the AST using the BFS algorithm (DFS is used if false). - - Returns: - The generator object. - """ - for expression in self.walk(bfs=bfs): - if isinstance(expression, expression_types): - yield expression - - def find_ancestor(self, *expression_types: t.Type[E]) -> t.Optional[E]: - """ - Returns a nearest parent matching expression_types. - - Args: - expression_types: the expression type(s) to match. - - Returns: - The parent node. 
- """ - ancestor = self.parent - while ancestor and not isinstance(ancestor, expression_types): - ancestor = ancestor.parent - return ancestor # type: ignore - - @property - def parent_select(self) -> t.Optional[Select]: - """ - Returns the parent select statement. - """ - return self.find_ancestor(Select) - - @property - def same_parent(self) -> bool: - """Returns if the parent is the same class as itself.""" - return type(self.parent) is self.__class__ - - def root(self) -> Expression: - """ - Returns the root expression of this tree. - """ - expression = self - while expression.parent: - expression = expression.parent - return expression - - def walk( - self, bfs: bool = True, prune: t.Optional[t.Callable[[Expression], bool]] = None - ) -> t.Iterator[Expression]: - """ - Returns a generator object which visits all nodes in this tree. - - Args: - bfs: if set to True the BFS traversal order will be applied, - otherwise the DFS traversal will be used instead. - prune: callable that returns True if the generator should stop traversing - this branch of the tree. - - Returns: - the generator object. - """ - if bfs: - yield from self.bfs(prune=prune) - else: - yield from self.dfs(prune=prune) - - def dfs( - self, prune: t.Optional[t.Callable[[Expression], bool]] = None - ) -> t.Iterator[Expression]: - """ - Returns a generator object which visits all nodes in this tree in - the DFS (Depth-first) order. - - Returns: - The generator object. - """ - stack = [self] - - while stack: - node = stack.pop() - - yield node - - if prune and prune(node): - continue - - for v in node.iter_expressions(reverse=True): - stack.append(v) - - def bfs( - self, prune: t.Optional[t.Callable[[Expression], bool]] = None - ) -> t.Iterator[Expression]: - """ - Returns a generator object which visits all nodes in this tree in - the BFS (Breadth-first) order. - - Returns: - The generator object. - """ - queue = deque([self]) - - while queue: - node = queue.popleft() - - yield node - - if prune and prune(node): - continue - - for v in node.iter_expressions(): - queue.append(v) - - def unnest(self): - """ - Returns the first non parenthesis child or self. - """ - expression = self - while type(expression) is Paren: - expression = expression.this - return expression - - def unalias(self): - """ - Returns the inner expression if this is an Alias. - """ - if isinstance(self, Alias): - return self.this - return self - - def unnest_operands(self): - """ - Returns unnested operands as a tuple. - """ - return tuple(arg.unnest() for arg in self.iter_expressions()) - - def flatten(self, unnest=True): - """ - Returns a generator which yields child nodes whose parents are the same class. - - A AND B AND C -> [A, B, C] - """ - for node in self.dfs(prune=lambda n: n.parent and type(n) is not self.__class__): - if type(node) is not self.__class__: - yield node.unnest() if unnest and not isinstance(node, Subquery) else node - - def __str__(self) -> str: - return self.sql() - - def __repr__(self) -> str: - return _to_s(self) - - def to_s(self) -> str: - """ - Same as __repr__, but includes additional information which can be useful - for debugging, like empty or missing args and the AST nodes' object IDs. - """ - return _to_s(self, verbose=True) - - def sql(self, dialect: DialectType = None, **opts) -> str: - """ - Returns SQL string representation of this tree. - - Args: - dialect: the dialect of the output SQL string (eg. "spark", "hive", "presto", "mysql"). - opts: other `sqlglot.generator.Generator` options. 
- - Returns: - The SQL string. - """ - from sqlglot.dialects import Dialect - - return Dialect.get_or_raise(dialect).generate(self, **opts) - - def transform(self, fun: t.Callable, *args: t.Any, copy: bool = True, **kwargs) -> Expression: - """ - Visits all tree nodes (excluding already transformed ones) - and applies the given transformation function to each node. - - Args: - fun: a function which takes a node as an argument and returns a - new transformed node or the same node without modifications. If the function - returns None, then the corresponding node will be removed from the syntax tree. - copy: if set to True a new tree instance is constructed, otherwise the tree is - modified in place. - - Returns: - The transformed tree. - """ - root = None - new_node = None - - for node in (self.copy() if copy else self).dfs(prune=lambda n: n is not new_node): - parent, arg_key, index = node.parent, node.arg_key, node.index - new_node = fun(node, *args, **kwargs) - - if not root: - root = new_node - elif parent and arg_key and new_node is not node: - parent.set(arg_key, new_node, index) - - assert root - return root.assert_is(Expression) - - @t.overload - def replace(self, expression: E) -> E: ... - - @t.overload - def replace(self, expression: None) -> None: ... - - def replace(self, expression): - """ - Swap out this expression with a new expression. - - For example:: - - >>> tree = Select().select("x").from_("tbl") - >>> tree.find(Column).replace(column("y")) - Column( - this=Identifier(this=y, quoted=False)) - >>> tree.sql() - 'SELECT y FROM tbl' - - Args: - expression: new node - - Returns: - The new expression or expressions. - """ - parent = self.parent - - if not parent or parent is expression: - return expression - - key = self.arg_key - value = parent.args.get(key) - - if type(expression) is list and isinstance(value, Expression): - # We are trying to replace an Expression with a list, so it's assumed that - # the intention was to really replace the parent of this expression. - value.parent.replace(expression) - else: - parent.set(key, expression, self.index) - - if expression is not self: - self.parent = None - self.arg_key = None - self.index = None - - return expression - - def pop(self: E) -> E: - """ - Remove this expression from its AST. - - Returns: - The popped expression. - """ - self.replace(None) - return self - - def assert_is(self, type_: t.Type[E]) -> E: - """ - Assert that this `Expression` is an instance of `type_`. - - If it is NOT an instance of `type_`, this raises an assertion error. - Otherwise, this returns this expression. - - Examples: - This is useful for type security in chained expressions: - - >>> import sqlglot - >>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql() - 'SELECT x, z FROM y' - """ - if not isinstance(self, type_): - raise AssertionError(f"{self} is not {type_}.") - return self - - def error_messages(self, args: t.Optional[t.Sequence] = None) -> t.List[str]: - """ - Checks if this expression is valid (e.g. all mandatory args are set). - - Args: - args: a sequence of values that were used to instantiate a Func expression. This is used - to check that the provided arguments don't exceed the function argument limit. - - Returns: - A list of error messages for all possible errors that were found. 
- """ - errors: t.List[str] = [] - - for k in self.args: - if k not in self.arg_types: - errors.append(f"Unexpected keyword: '{k}' for {self.__class__}") - for k, mandatory in self.arg_types.items(): - v = self.args.get(k) - if mandatory and (v is None or (isinstance(v, list) and not v)): - errors.append(f"Required keyword: '{k}' missing for {self.__class__}") - - if ( - args - and isinstance(self, Func) - and len(args) > len(self.arg_types) - and not self.is_var_len_args - ): - errors.append( - f"The number of provided arguments ({len(args)}) is greater than " - f"the maximum number of supported arguments ({len(self.arg_types)})" - ) - - return errors - - def dump(self): - """ - Dump this Expression to a JSON-serializable dict. - """ - from sqlglot.serde import dump - - return dump(self) - - @classmethod - def load(cls, obj): - """ - Load a dict (as returned by `Expression.dump`) into an Expression instance. - """ - from sqlglot.serde import load - - return load(obj) - - def and_( - self, - *expressions: t.Optional[ExpOrStr], - dialect: DialectType = None, - copy: bool = True, - wrap: bool = True, - **opts, - ) -> Condition: - """ - AND this condition with one or multiple expressions. - - Example: - >>> condition("x=1").and_("y=1").sql() - 'x = 1 AND y = 1' - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - dialect: the dialect used to parse the input expression. - copy: whether to copy the involved expressions (only applies to Expressions). - wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid - precedence issues, but can be turned off when the produced AST is too deep and - causes recursion-related issues. - opts: other options to use to parse the input expressions. - - Returns: - The new And condition. - """ - return and_(self, *expressions, dialect=dialect, copy=copy, wrap=wrap, **opts) - - def or_( - self, - *expressions: t.Optional[ExpOrStr], - dialect: DialectType = None, - copy: bool = True, - wrap: bool = True, - **opts, - ) -> Condition: - """ - OR this condition with one or multiple expressions. - - Example: - >>> condition("x=1").or_("y=1").sql() - 'x = 1 OR y = 1' - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - dialect: the dialect used to parse the input expression. - copy: whether to copy the involved expressions (only applies to Expressions). - wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid - precedence issues, but can be turned off when the produced AST is too deep and - causes recursion-related issues. - opts: other options to use to parse the input expressions. - - Returns: - The new Or condition. - """ - return or_(self, *expressions, dialect=dialect, copy=copy, wrap=wrap, **opts) - - def not_(self, copy: bool = True): - """ - Wrap this condition with NOT. - - Example: - >>> condition("x=1").not_().sql() - 'NOT x = 1' - - Args: - copy: whether to copy this object. - - Returns: - The new Not instance. - """ - return not_(self, copy=copy) - - def update_positions( - self: E, other: t.Optional[Token | Expression] = None, **kwargs: t.Any - ) -> E: - """ - Update this expression with positions from a token or other expression. - - Args: - other: a token or expression to update this expression with. - - Returns: - The updated expression. 
- """ - if isinstance(other, Expression): - self.meta.update({k: v for k, v in other.meta.items() if k in POSITION_META_KEYS}) - elif other is not None: - self.meta.update( - { - "line": other.line, - "col": other.col, - "start": other.start, - "end": other.end, - } - ) - self.meta.update({k: v for k, v in kwargs.items() if k in POSITION_META_KEYS}) - return self - - def as_( - self, - alias: str | Identifier, - quoted: t.Optional[bool] = None, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Alias: - return alias_(self, alias, quoted=quoted, dialect=dialect, copy=copy, **opts) - - def _binop(self, klass: t.Type[E], other: t.Any, reverse: bool = False) -> E: - this = self.copy() - other = convert(other, copy=True) - if not isinstance(this, klass) and not isinstance(other, klass): - this = _wrap(this, Binary) - other = _wrap(other, Binary) - if reverse: - return klass(this=other, expression=this) - return klass(this=this, expression=other) - - def __getitem__(self, other: ExpOrStr | t.Tuple[ExpOrStr]) -> Bracket: - return Bracket( - this=self.copy(), expressions=[convert(e, copy=True) for e in ensure_list(other)] - ) - - def __iter__(self) -> t.Iterator: - if "expressions" in self.arg_types: - return iter(self.args.get("expressions") or []) - # We define this because __getitem__ converts Expression into an iterable, which is - # problematic because one can hit infinite loops if they do "for x in some_expr: ..." - # See: https://peps.python.org/pep-0234/ - raise TypeError(f"'{self.__class__.__name__}' object is not iterable") - - def isin( - self, - *expressions: t.Any, - query: t.Optional[ExpOrStr] = None, - unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None, - copy: bool = True, - **opts, - ) -> In: - subquery = maybe_parse(query, copy=copy, **opts) if query else None - if subquery and not isinstance(subquery, Subquery): - subquery = subquery.subquery(copy=False) - - return In( - this=maybe_copy(self, copy), - expressions=[convert(e, copy=copy) for e in expressions], - query=subquery, - unnest=( - Unnest( - expressions=[ - maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) - for e in ensure_list(unnest) - ] - ) - if unnest - else None - ), - ) - - def between(self, low: t.Any, high: t.Any, copy: bool = True, **opts) -> Between: - return Between( - this=maybe_copy(self, copy), - low=convert(low, copy=copy, **opts), - high=convert(high, copy=copy, **opts), - ) - - def is_(self, other: ExpOrStr) -> Is: - return self._binop(Is, other) - - def like(self, other: ExpOrStr) -> Like: - return self._binop(Like, other) - - def ilike(self, other: ExpOrStr) -> ILike: - return self._binop(ILike, other) - - def eq(self, other: t.Any) -> EQ: - return self._binop(EQ, other) - - def neq(self, other: t.Any) -> NEQ: - return self._binop(NEQ, other) - - def rlike(self, other: ExpOrStr) -> RegexpLike: - return self._binop(RegexpLike, other) - - def div(self, other: ExpOrStr, typed: bool = False, safe: bool = False) -> Div: - div = self._binop(Div, other) - div.args["typed"] = typed - div.args["safe"] = safe - return div - - def asc(self, nulls_first: bool = True) -> Ordered: - return Ordered(this=self.copy(), nulls_first=nulls_first) - - def desc(self, nulls_first: bool = False) -> Ordered: - return Ordered(this=self.copy(), desc=True, nulls_first=nulls_first) - - def __lt__(self, other: t.Any) -> LT: - return self._binop(LT, other) - - def __le__(self, other: t.Any) -> LTE: - return self._binop(LTE, other) - - def __gt__(self, other: t.Any) -> GT: - return self._binop(GT, other) 
- - def __ge__(self, other: t.Any) -> GTE: - return self._binop(GTE, other) - - def __add__(self, other: t.Any) -> Add: - return self._binop(Add, other) - - def __radd__(self, other: t.Any) -> Add: - return self._binop(Add, other, reverse=True) - - def __sub__(self, other: t.Any) -> Sub: - return self._binop(Sub, other) - - def __rsub__(self, other: t.Any) -> Sub: - return self._binop(Sub, other, reverse=True) - - def __mul__(self, other: t.Any) -> Mul: - return self._binop(Mul, other) - - def __rmul__(self, other: t.Any) -> Mul: - return self._binop(Mul, other, reverse=True) - - def __truediv__(self, other: t.Any) -> Div: - return self._binop(Div, other) - - def __rtruediv__(self, other: t.Any) -> Div: - return self._binop(Div, other, reverse=True) - - def __floordiv__(self, other: t.Any) -> IntDiv: - return self._binop(IntDiv, other) - - def __rfloordiv__(self, other: t.Any) -> IntDiv: - return self._binop(IntDiv, other, reverse=True) - - def __mod__(self, other: t.Any) -> Mod: - return self._binop(Mod, other) - - def __rmod__(self, other: t.Any) -> Mod: - return self._binop(Mod, other, reverse=True) - - def __pow__(self, other: t.Any) -> Pow: - return self._binop(Pow, other) - - def __rpow__(self, other: t.Any) -> Pow: - return self._binop(Pow, other, reverse=True) - - def __and__(self, other: t.Any) -> And: - return self._binop(And, other) - - def __rand__(self, other: t.Any) -> And: - return self._binop(And, other, reverse=True) - - def __or__(self, other: t.Any) -> Or: - return self._binop(Or, other) - - def __ror__(self, other: t.Any) -> Or: - return self._binop(Or, other, reverse=True) - - def __neg__(self) -> Neg: - return Neg(this=_wrap(self.copy(), Binary)) - - def __invert__(self) -> Not: - return not_(self.copy()) - - -IntoType = t.Union[ - str, - t.Type[Expression], - t.Collection[t.Union[str, t.Type[Expression]]], -] -ExpOrStr = t.Union[str, Expression] - - -class Condition(Expression): - """Logical conditions like x AND y, or simply x""" - - -class Predicate(Condition): - """Relationships like x = y, x > 1, x >= y.""" - - -class DerivedTable(Expression): - @property - def selects(self) -> t.List[Expression]: - return self.this.selects if isinstance(self.this, Query) else [] - - @property - def named_selects(self) -> t.List[str]: - return [select.output_name for select in self.selects] - - -class Query(Expression): - def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Subquery: - """ - Returns a `Subquery` that wraps around this query. - - Example: - >>> subquery = Select().select("x").from_("tbl").subquery() - >>> Select().select("x").from_(subquery).sql() - 'SELECT x FROM (SELECT x FROM tbl)' - - Args: - alias: an optional alias for the subquery. - copy: if `False`, modify this expression instance in-place. - """ - instance = maybe_copy(self, copy) - if not isinstance(alias, Expression): - alias = TableAlias(this=to_identifier(alias)) if alias else None - - return Subquery(this=instance, alias=alias) - - def limit( - self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Q: - """ - Adds a LIMIT clause to this query. - - Example: - >>> select("1").union(select("1")).limit(1).sql() - 'SELECT 1 UNION SELECT 1 LIMIT 1' - - Args: - expression: the SQL code string to parse. - This can also be an integer. - If a `Limit` instance is passed, it will be used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Limit`. - dialect: the dialect used to parse the input expression. 
- copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - A limited Select expression. - """ - return _apply_builder( - expression=expression, - instance=self, - arg="limit", - into=Limit, - prefix="LIMIT", - dialect=dialect, - copy=copy, - into_arg="expression", - **opts, - ) - - def offset( - self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts - ) -> Q: - """ - Set the OFFSET expression. - - Example: - >>> Select().from_("tbl").select("x").offset(10).sql() - 'SELECT x FROM tbl OFFSET 10' - - Args: - expression: the SQL code string to parse. - This can also be an integer. - If a `Offset` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Offset`. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_builder( - expression=expression, - instance=self, - arg="offset", - into=Offset, - prefix="OFFSET", - dialect=dialect, - copy=copy, - into_arg="expression", - **opts, - ) - - def order_by( - self: Q, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Q: - """ - Set the ORDER BY expression. - - Example: - >>> Select().from_("tbl").select("x").order_by("x DESC").sql() - 'SELECT x FROM tbl ORDER BY x DESC' - - Args: - *expressions: the SQL code strings to parse. - If a `Group` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Order`. - append: if `True`, add to any existing expressions. - Otherwise, this flattens all the `Order` expression into a single expression. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_child_list_builder( - *expressions, - instance=self, - arg="order", - append=append, - copy=copy, - prefix="ORDER BY", - into=Order, - dialect=dialect, - **opts, - ) - - @property - def ctes(self) -> t.List[CTE]: - """Returns a list of all the CTEs attached to this query.""" - with_ = self.args.get("with") - return with_.expressions if with_ else [] - - @property - def selects(self) -> t.List[Expression]: - """Returns the query's projections.""" - raise NotImplementedError("Query objects must implement `selects`") - - @property - def named_selects(self) -> t.List[str]: - """Returns the output names of the query's projections.""" - raise NotImplementedError("Query objects must implement `named_selects`") - - def select( - self: Q, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Q: - """ - Append to or set the SELECT expressions. - - Example: - >>> Select().select("x", "y").sql() - 'SELECT x, y' - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. 
- - Returns: - The modified Query expression. - """ - raise NotImplementedError("Query objects must implement `select`") - - def where( - self: Q, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Q: - """ - Append to or set the WHERE expressions. - - Examples: - >>> Select().select("x").from_("tbl").where("x = 'a' OR x < 'b'").sql() - "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'" - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - Multiple expressions are combined with an AND operator. - append: if `True`, AND the new expressions to any existing expression. - Otherwise, this resets the expression. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified expression. - """ - return _apply_conjunction_builder( - *[expr.this if isinstance(expr, Where) else expr for expr in expressions], - instance=self, - arg="where", - append=append, - into=Where, - dialect=dialect, - copy=copy, - **opts, - ) - - def with_( - self: Q, - alias: ExpOrStr, - as_: ExpOrStr, - recursive: t.Optional[bool] = None, - materialized: t.Optional[bool] = None, - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - scalar: bool = False, - **opts, - ) -> Q: - """ - Append to or set the common table expressions. - - Example: - >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql() - 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' - - Args: - alias: the SQL code string to parse as the table name. - If an `Expression` instance is passed, this is used as-is. - as_: the SQL code string to parse as the table expression. - If an `Expression` instance is passed, it will be used as-is. - recursive: set the RECURSIVE part of the expression. Defaults to `False`. - materialized: set the MATERIALIZED part of the expression. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - scalar: if `True`, this is a scalar common table expression. - opts: other options to use to parse the input expressions. - - Returns: - The modified expression. - """ - return _apply_cte_builder( - self, - alias, - as_, - recursive=recursive, - materialized=materialized, - append=append, - dialect=dialect, - copy=copy, - scalar=scalar, - **opts, - ) - - def union( - self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts - ) -> Union: - """ - Builds a UNION expression. - - Example: - >>> import sqlglot - >>> sqlglot.parse_one("SELECT * FROM foo").union("SELECT * FROM bla").sql() - 'SELECT * FROM foo UNION SELECT * FROM bla' - - Args: - expressions: the SQL code strings. - If `Expression` instances are passed, they will be used as-is. - distinct: set the DISTINCT flag if and only if this is true. - dialect: the dialect used to parse the input expression. - opts: other options to use to parse the input expressions. - - Returns: - The new Union expression. - """ - return union(self, *expressions, distinct=distinct, dialect=dialect, **opts) - - def intersect( - self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts - ) -> Intersect: - """ - Builds an INTERSECT expression. 
- - Example: - >>> import sqlglot - >>> sqlglot.parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla").sql() - 'SELECT * FROM foo INTERSECT SELECT * FROM bla' - - Args: - expressions: the SQL code strings. - If `Expression` instances are passed, they will be used as-is. - distinct: set the DISTINCT flag if and only if this is true. - dialect: the dialect used to parse the input expression. - opts: other options to use to parse the input expressions. - - Returns: - The new Intersect expression. - """ - return intersect(self, *expressions, distinct=distinct, dialect=dialect, **opts) - - def except_( - self, *expressions: ExpOrStr, distinct: bool = True, dialect: DialectType = None, **opts - ) -> Except: - """ - Builds an EXCEPT expression. - - Example: - >>> import sqlglot - >>> sqlglot.parse_one("SELECT * FROM foo").except_("SELECT * FROM bla").sql() - 'SELECT * FROM foo EXCEPT SELECT * FROM bla' - - Args: - expressions: the SQL code strings. - If `Expression` instance are passed, they will be used as-is. - distinct: set the DISTINCT flag if and only if this is true. - dialect: the dialect used to parse the input expression. - opts: other options to use to parse the input expressions. - - Returns: - The new Except expression. - """ - return except_(self, *expressions, distinct=distinct, dialect=dialect, **opts) - - -class UDTF(DerivedTable): - @property - def selects(self) -> t.List[Expression]: - alias = self.args.get("alias") - return alias.columns if alias else [] - - -class Cache(Expression): - arg_types = { - "this": True, - "lazy": False, - "options": False, - "expression": False, - } - - -class Uncache(Expression): - arg_types = {"this": True, "exists": False} - - -class Refresh(Expression): - pass - - -class DDL(Expression): - @property - def ctes(self) -> t.List[CTE]: - """Returns a list of all the CTEs attached to this statement.""" - with_ = self.args.get("with") - return with_.expressions if with_ else [] - - @property - def selects(self) -> t.List[Expression]: - """If this statement contains a query (e.g. a CTAS), this returns the query's projections.""" - return self.expression.selects if isinstance(self.expression, Query) else [] - - @property - def named_selects(self) -> t.List[str]: - """ - If this statement contains a query (e.g. a CTAS), this returns the output - names of the query's projections. - """ - return self.expression.named_selects if isinstance(self.expression, Query) else [] - - -class DML(Expression): - def returning( - self, - expression: ExpOrStr, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> "Self": - """ - Set the RETURNING expression. Not supported by all dialects. - - Example: - >>> delete("tbl").returning("*", dialect="postgres").sql() - 'DELETE FROM tbl RETURNING *' - - Args: - expression: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Delete: the modified expression. 
- """ - return _apply_builder( - expression=expression, - instance=self, - arg="returning", - prefix="RETURNING", - dialect=dialect, - copy=copy, - into=Returning, - **opts, - ) - - -class Create(DDL): - arg_types = { - "with": False, - "this": True, - "kind": True, - "expression": False, - "exists": False, - "properties": False, - "replace": False, - "refresh": False, - "unique": False, - "indexes": False, - "no_schema_binding": False, - "begin": False, - "end": False, - "clone": False, - "concurrently": False, - "clustered": False, - } - - @property - def kind(self) -> t.Optional[str]: - kind = self.args.get("kind") - return kind and kind.upper() - - -class SequenceProperties(Expression): - arg_types = { - "increment": False, - "minvalue": False, - "maxvalue": False, - "cache": False, - "start": False, - "owned": False, - "options": False, - } - - -class TruncateTable(Expression): - arg_types = { - "expressions": True, - "is_database": False, - "exists": False, - "only": False, - "cluster": False, - "identity": False, - "option": False, - "partition": False, - } - - -# https://docs.snowflake.com/en/sql-reference/sql/create-clone -# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_clone_statement -# https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_table_copy -class Clone(Expression): - arg_types = {"this": True, "shallow": False, "copy": False} - - -class Describe(Expression): - arg_types = { - "this": True, - "style": False, - "kind": False, - "expressions": False, - "partition": False, - "format": False, - } - - -# https://duckdb.org/docs/sql/statements/attach.html#attach -class Attach(Expression): - arg_types = {"this": True, "exists": False, "expressions": False} - - -# https://duckdb.org/docs/sql/statements/attach.html#detach -class Detach(Expression): - arg_types = {"this": True, "exists": False} - - -# https://duckdb.org/docs/guides/meta/summarize.html -class Summarize(Expression): - arg_types = {"this": True, "table": False} - - -class Kill(Expression): - arg_types = {"this": True, "kind": False} - - -class Pragma(Expression): - pass - - -class Declare(Expression): - arg_types = {"expressions": True} - - -class DeclareItem(Expression): - arg_types = {"this": True, "kind": True, "default": False} - - -class Set(Expression): - arg_types = {"expressions": False, "unset": False, "tag": False} - - -class Heredoc(Expression): - arg_types = {"this": True, "tag": False} - - -class SetItem(Expression): - arg_types = { - "this": False, - "expressions": False, - "kind": False, - "collate": False, # MySQL SET NAMES statement - "global": False, - } - - -class Show(Expression): - arg_types = { - "this": True, - "history": False, - "terse": False, - "target": False, - "offset": False, - "starts_with": False, - "limit": False, - "from": False, - "like": False, - "where": False, - "db": False, - "scope": False, - "scope_kind": False, - "full": False, - "mutex": False, - "query": False, - "channel": False, - "global": False, - "log": False, - "position": False, - "types": False, - "privileges": False, - } - - -class UserDefinedFunction(Expression): - arg_types = {"this": True, "expressions": False, "wrapped": False} - - -class CharacterSet(Expression): - arg_types = {"this": True, "default": False} - - -class RecursiveWithSearch(Expression): - arg_types = {"kind": True, "this": True, "expression": True, "using": False} - - -class With(Expression): - arg_types = {"expressions": True, "recursive": False, 
"search": False} - - @property - def recursive(self) -> bool: - return bool(self.args.get("recursive")) - - -class WithinGroup(Expression): - arg_types = {"this": True, "expression": False} - - -# clickhouse supports scalar ctes -# https://clickhouse.com/docs/en/sql-reference/statements/select/with -class CTE(DerivedTable): - arg_types = { - "this": True, - "alias": True, - "scalar": False, - "materialized": False, - } - - -class ProjectionDef(Expression): - arg_types = {"this": True, "expression": True} - - -class TableAlias(Expression): - arg_types = {"this": False, "columns": False} - - @property - def columns(self): - return self.args.get("columns") or [] - - -class BitString(Condition): - pass - - -class HexString(Condition): - arg_types = {"this": True, "is_integer": False} - - -class ByteString(Condition): - pass - - -class RawString(Condition): - pass - - -class UnicodeString(Condition): - arg_types = {"this": True, "escape": False} - - -class Column(Condition): - arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False} - - @property - def table(self) -> str: - return self.text("table") - - @property - def db(self) -> str: - return self.text("db") - - @property - def catalog(self) -> str: - return self.text("catalog") - - @property - def output_name(self) -> str: - return self.name - - @property - def parts(self) -> t.List[Identifier]: - """Return the parts of a column in order catalog, db, table, name.""" - return [ - t.cast(Identifier, self.args[part]) - for part in ("catalog", "db", "table", "this") - if self.args.get(part) - ] - - def to_dot(self, include_dots: bool = True) -> Dot | Identifier: - """Converts the column into a dot expression.""" - parts = self.parts - parent = self.parent - - if include_dots: - while isinstance(parent, Dot): - parts.append(parent.expression) - parent = parent.parent - - return Dot.build(deepcopy(parts)) if len(parts) > 1 else parts[0] - - -class ColumnPosition(Expression): - arg_types = {"this": False, "position": True} - - -class ColumnDef(Expression): - arg_types = { - "this": True, - "kind": False, - "constraints": False, - "exists": False, - "position": False, - "default": False, - "output": False, - } - - @property - def constraints(self) -> t.List[ColumnConstraint]: - return self.args.get("constraints") or [] - - @property - def kind(self) -> t.Optional[DataType]: - return self.args.get("kind") - - -class AlterColumn(Expression): - arg_types = { - "this": True, - "dtype": False, - "collate": False, - "using": False, - "default": False, - "drop": False, - "comment": False, - "allow_null": False, - "visible": False, - } - - -# https://dev.mysql.com/doc/refman/8.0/en/invisible-indexes.html -class AlterIndex(Expression): - arg_types = {"this": True, "visible": True} - - -# https://docs.aws.amazon.com/redshift/latest/dg/r_ALTER_TABLE.html -class AlterDistStyle(Expression): - pass - - -class AlterSortKey(Expression): - arg_types = {"this": False, "expressions": False, "compound": False} - - -class AlterSet(Expression): - arg_types = { - "expressions": False, - "option": False, - "tablespace": False, - "access_method": False, - "file_format": False, - "copy_options": False, - "tag": False, - "location": False, - "serde": False, - } - - -class RenameColumn(Expression): - arg_types = {"this": True, "to": True, "exists": False} - - -class AlterRename(Expression): - pass - - -class SwapTable(Expression): - pass - - -class Comment(Expression): - arg_types = { - "this": True, - "kind": True, - "expression": True, - 
"exists": False, - "materialized": False, - } - - -class Comprehension(Expression): - arg_types = {"this": True, "expression": True, "iterator": True, "condition": False} - - -# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl -class MergeTreeTTLAction(Expression): - arg_types = { - "this": True, - "delete": False, - "recompress": False, - "to_disk": False, - "to_volume": False, - } - - -# https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl -class MergeTreeTTL(Expression): - arg_types = { - "expressions": True, - "where": False, - "group": False, - "aggregates": False, - } - - -# https://dev.mysql.com/doc/refman/8.0/en/create-table.html -class IndexConstraintOption(Expression): - arg_types = { - "key_block_size": False, - "using": False, - "parser": False, - "comment": False, - "visible": False, - "engine_attr": False, - "secondary_engine_attr": False, - } - - -class ColumnConstraint(Expression): - arg_types = {"this": False, "kind": True} - - @property - def kind(self) -> ColumnConstraintKind: - return self.args["kind"] - - -class ColumnConstraintKind(Expression): - pass - - -class AutoIncrementColumnConstraint(ColumnConstraintKind): - pass - - -class PeriodForSystemTimeConstraint(ColumnConstraintKind): - arg_types = {"this": True, "expression": True} - - -class CaseSpecificColumnConstraint(ColumnConstraintKind): - arg_types = {"not_": True} - - -class CharacterSetColumnConstraint(ColumnConstraintKind): - arg_types = {"this": True} - - -class CheckColumnConstraint(ColumnConstraintKind): - arg_types = {"this": True, "enforced": False} - - -class ClusteredColumnConstraint(ColumnConstraintKind): - pass - - -class CollateColumnConstraint(ColumnConstraintKind): - pass - - -class CommentColumnConstraint(ColumnConstraintKind): - pass - - -class CompressColumnConstraint(ColumnConstraintKind): - arg_types = {"this": False} - - -class DateFormatColumnConstraint(ColumnConstraintKind): - arg_types = {"this": True} - - -class DefaultColumnConstraint(ColumnConstraintKind): - pass - - -class EncodeColumnConstraint(ColumnConstraintKind): - pass - - -# https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-EXCLUDE -class ExcludeColumnConstraint(ColumnConstraintKind): - pass - - -class EphemeralColumnConstraint(ColumnConstraintKind): - arg_types = {"this": False} - - -class WithOperator(Expression): - arg_types = {"this": True, "op": True} - - -class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind): - # this: True -> ALWAYS, this: False -> BY DEFAULT - arg_types = { - "this": False, - "expression": False, - "on_null": False, - "start": False, - "increment": False, - "minvalue": False, - "maxvalue": False, - "cycle": False, - } - - -class GeneratedAsRowColumnConstraint(ColumnConstraintKind): - arg_types = {"start": False, "hidden": False} - - -# https://dev.mysql.com/doc/refman/8.0/en/create-table.html -# https://github.com/ClickHouse/ClickHouse/blob/master/src/Parsers/ParserCreateQuery.h#L646 -class IndexColumnConstraint(ColumnConstraintKind): - arg_types = { - "this": False, - "expressions": False, - "kind": False, - "index_type": False, - "options": False, - "expression": False, # Clickhouse - "granularity": False, - } - - -class InlineLengthColumnConstraint(ColumnConstraintKind): - pass - - -class NonClusteredColumnConstraint(ColumnConstraintKind): - pass - - -class NotForReplicationColumnConstraint(ColumnConstraintKind): - arg_types = {} - - -# 
https://docs.snowflake.com/en/sql-reference/sql/create-table -class MaskingPolicyColumnConstraint(ColumnConstraintKind): - arg_types = {"this": True, "expressions": False} - - -class NotNullColumnConstraint(ColumnConstraintKind): - arg_types = {"allow_null": False} - - -# https://dev.mysql.com/doc/refman/5.7/en/timestamp-initialization.html -class OnUpdateColumnConstraint(ColumnConstraintKind): - pass - - -class PrimaryKeyColumnConstraint(ColumnConstraintKind): - arg_types = {"desc": False, "options": False} - - -class TitleColumnConstraint(ColumnConstraintKind): - pass - - -class UniqueColumnConstraint(ColumnConstraintKind): - arg_types = { - "this": False, - "index_type": False, - "on_conflict": False, - "nulls": False, - "options": False, - } - - -class UppercaseColumnConstraint(ColumnConstraintKind): - arg_types: t.Dict[str, t.Any] = {} - - -# https://docs.risingwave.com/processing/watermarks#syntax -class WatermarkColumnConstraint(Expression): - arg_types = {"this": True, "expression": True} - - -class PathColumnConstraint(ColumnConstraintKind): - pass - - -# https://docs.snowflake.com/en/sql-reference/sql/create-table -class ProjectionPolicyColumnConstraint(ColumnConstraintKind): - pass - - -# computed column expression -# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-table-transact-sql?view=sql-server-ver16 -class ComputedColumnConstraint(ColumnConstraintKind): - arg_types = {"this": True, "persisted": False, "not_null": False} - - -class Constraint(Expression): - arg_types = {"this": True, "expressions": True} - - -class Delete(DML): - arg_types = { - "with": False, - "this": False, - "using": False, - "where": False, - "returning": False, - "limit": False, - "tables": False, # Multiple-Table Syntax (MySQL) - "cluster": False, # Clickhouse - } - - def delete( - self, - table: ExpOrStr, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Delete: - """ - Create a DELETE expression or replace the table on an existing DELETE expression. - - Example: - >>> delete("tbl").sql() - 'DELETE FROM tbl' - - Args: - table: the table from which to delete. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Delete: the modified expression. - """ - return _apply_builder( - expression=table, - instance=self, - arg="this", - dialect=dialect, - into=Table, - copy=copy, - **opts, - ) - - def where( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Delete: - """ - Append to or set the WHERE expressions. - - Example: - >>> delete("tbl").where("x = 'a' OR x < 'b'").sql() - "DELETE FROM tbl WHERE x = 'a' OR x < 'b'" - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - Multiple expressions are combined with an AND operator. - append: if `True`, AND the new expressions to any existing expression. - Otherwise, this resets the expression. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Delete: the modified expression. 
- """ - return _apply_conjunction_builder( - *expressions, - instance=self, - arg="where", - append=append, - into=Where, - dialect=dialect, - copy=copy, - **opts, - ) - - -class Drop(Expression): - arg_types = { - "this": False, - "kind": False, - "expressions": False, - "exists": False, - "temporary": False, - "materialized": False, - "cascade": False, - "constraints": False, - "purge": False, - "cluster": False, - "concurrently": False, - } - - @property - def kind(self) -> t.Optional[str]: - kind = self.args.get("kind") - return kind and kind.upper() - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/export-statements -class Export(Expression): - arg_types = {"this": True, "connection": False, "options": True} - - -class Filter(Expression): - arg_types = {"this": True, "expression": True} - - -class Check(Expression): - pass - - -class Changes(Expression): - arg_types = {"information": True, "at_before": False, "end": False} - - -# https://docs.snowflake.com/en/sql-reference/constructs/connect-by -class Connect(Expression): - arg_types = {"start": False, "connect": True, "nocycle": False} - - -class CopyParameter(Expression): - arg_types = {"this": True, "expression": False, "expressions": False} - - -class Copy(DML): - arg_types = { - "this": True, - "kind": True, - "files": True, - "credentials": False, - "format": False, - "params": False, - } - - -class Credentials(Expression): - arg_types = { - "credentials": False, - "encryption": False, - "storage": False, - "iam_role": False, - "region": False, - } - - -class Prior(Expression): - pass - - -class Directory(Expression): - # https://spark.apache.org/docs/3.0.0-preview/sql-ref-syntax-dml-insert-overwrite-directory-hive.html - arg_types = {"this": True, "local": False, "row_format": False} - - -class ForeignKey(Expression): - arg_types = { - "expressions": False, - "reference": False, - "delete": False, - "update": False, - "options": False, - } - - -class ColumnPrefix(Expression): - arg_types = {"this": True, "expression": True} - - -class PrimaryKey(Expression): - arg_types = {"expressions": True, "options": False} - - -# https://www.postgresql.org/docs/9.1/sql-selectinto.html -# https://docs.aws.amazon.com/redshift/latest/dg/r_SELECT_INTO.html#r_SELECT_INTO-examples -class Into(Expression): - arg_types = { - "this": False, - "temporary": False, - "unlogged": False, - "bulk_collect": False, - "expressions": False, - } - - -class From(Expression): - @property - def name(self) -> str: - return self.this.name - - @property - def alias_or_name(self) -> str: - return self.this.alias_or_name - - -class Having(Expression): - pass - - -class Hint(Expression): - arg_types = {"expressions": True} - - -class JoinHint(Expression): - arg_types = {"this": True, "expressions": True} - - -class Identifier(Expression): - arg_types = {"this": True, "quoted": False, "global": False, "temporary": False} - - @property - def quoted(self) -> bool: - return bool(self.args.get("quoted")) - - @property - def hashable_args(self) -> t.Any: - return (self.this, self.quoted) - - @property - def output_name(self) -> str: - return self.name - - -# https://www.postgresql.org/docs/current/indexes-opclass.html -class Opclass(Expression): - arg_types = {"this": True, "expression": True} - - -class Index(Expression): - arg_types = { - "this": False, - "table": False, - "unique": False, - "primary": False, - "amp": False, # teradata - "params": False, - } - - -class IndexParameters(Expression): - arg_types = { - "using": False, - "include": False, - 
"columns": False, - "with_storage": False, - "partition_by": False, - "tablespace": False, - "where": False, - "on": False, - } - - -class Insert(DDL, DML): - arg_types = { - "hint": False, - "with": False, - "is_function": False, - "this": False, - "expression": False, - "conflict": False, - "returning": False, - "overwrite": False, - "exists": False, - "alternative": False, - "where": False, - "ignore": False, - "by_name": False, - "stored": False, - "partition": False, - "settings": False, - "source": False, - } - - def with_( - self, - alias: ExpOrStr, - as_: ExpOrStr, - recursive: t.Optional[bool] = None, - materialized: t.Optional[bool] = None, - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Insert: - """ - Append to or set the common table expressions. - - Example: - >>> insert("SELECT x FROM cte", "t").with_("cte", as_="SELECT * FROM tbl").sql() - 'WITH cte AS (SELECT * FROM tbl) INSERT INTO t SELECT x FROM cte' - - Args: - alias: the SQL code string to parse as the table name. - If an `Expression` instance is passed, this is used as-is. - as_: the SQL code string to parse as the table expression. - If an `Expression` instance is passed, it will be used as-is. - recursive: set the RECURSIVE part of the expression. Defaults to `False`. - materialized: set the MATERIALIZED part of the expression. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified expression. - """ - return _apply_cte_builder( - self, - alias, - as_, - recursive=recursive, - materialized=materialized, - append=append, - dialect=dialect, - copy=copy, - **opts, - ) - - -class ConditionalInsert(Expression): - arg_types = {"this": True, "expression": False, "else_": False} - - -class MultitableInserts(Expression): - arg_types = {"expressions": True, "kind": True, "source": True} - - -class OnConflict(Expression): - arg_types = { - "duplicate": False, - "expressions": False, - "action": False, - "conflict_keys": False, - "constraint": False, - "where": False, - } - - -class OnCondition(Expression): - arg_types = {"error": False, "empty": False, "null": False} - - -class Returning(Expression): - arg_types = {"expressions": True, "into": False} - - -# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html -class Introducer(Expression): - arg_types = {"this": True, "expression": True} - - -# national char, like n'utf8' -class National(Expression): - pass - - -class LoadData(Expression): - arg_types = { - "this": True, - "local": False, - "overwrite": False, - "inpath": True, - "partition": False, - "input_format": False, - "serde": False, - } - - -class Partition(Expression): - arg_types = {"expressions": True, "subpartition": False} - - -class PartitionRange(Expression): - arg_types = {"this": True, "expression": True} - - -# https://clickhouse.com/docs/en/sql-reference/statements/alter/partition#how-to-set-partition-expression -class PartitionId(Expression): - pass - - -class Fetch(Expression): - arg_types = { - "direction": False, - "count": False, - "limit_options": False, - } - - -class Grant(Expression): - arg_types = { - "privileges": True, - "kind": False, - "securable": True, - "principals": True, - "grant_option": False, - } - - -class Group(Expression): - arg_types = { - "expressions": False, - "grouping_sets": 
False, - "cube": False, - "rollup": False, - "totals": False, - "all": False, - } - - -class Cube(Expression): - arg_types = {"expressions": False} - - -class Rollup(Expression): - arg_types = {"expressions": False} - - -class GroupingSets(Expression): - arg_types = {"expressions": True} - - -class Lambda(Expression): - arg_types = {"this": True, "expressions": True} - - -class Limit(Expression): - arg_types = { - "this": False, - "expression": True, - "offset": False, - "limit_options": False, - "expressions": False, - } - - -class LimitOptions(Expression): - arg_types = { - "percent": False, - "rows": False, - "with_ties": False, - } - - -class Literal(Condition): - arg_types = {"this": True, "is_string": True} - - @property - def hashable_args(self) -> t.Any: - return (self.this, self.args.get("is_string")) - - @classmethod - def number(cls, number) -> Literal: - return cls(this=str(number), is_string=False) - - @classmethod - def string(cls, string) -> Literal: - return cls(this=str(string), is_string=True) - - @property - def output_name(self) -> str: - return self.name - - def to_py(self) -> int | str | Decimal: - if self.is_number: - try: - return int(self.this) - except ValueError: - return Decimal(self.this) - return self.this - - -class Join(Expression): - arg_types = { - "this": True, - "on": False, - "side": False, - "kind": False, - "using": False, - "method": False, - "global": False, - "hint": False, - "match_condition": False, # Snowflake - "expressions": False, - "pivots": False, - } - - @property - def method(self) -> str: - return self.text("method").upper() - - @property - def kind(self) -> str: - return self.text("kind").upper() - - @property - def side(self) -> str: - return self.text("side").upper() - - @property - def hint(self) -> str: - return self.text("hint").upper() - - @property - def alias_or_name(self) -> str: - return self.this.alias_or_name - - @property - def is_semi_or_anti_join(self) -> bool: - return self.kind in ("SEMI", "ANTI") - - def on( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Join: - """ - Append to or set the ON expressions. - - Example: - >>> import sqlglot - >>> sqlglot.parse_one("JOIN x", into=Join).on("y = 1").sql() - 'JOIN x ON y = 1' - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - Multiple expressions are combined with an AND operator. - append: if `True`, AND the new expressions to any existing expression. - Otherwise, this resets the expression. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Join expression. - """ - join = _apply_conjunction_builder( - *expressions, - instance=self, - arg="on", - append=append, - dialect=dialect, - copy=copy, - **opts, - ) - - if join.kind == "CROSS": - join.set("kind", None) - - return join - - def using( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Join: - """ - Append to or set the USING expressions. - - Example: - >>> import sqlglot - >>> sqlglot.parse_one("JOIN x", into=Join).using("foo", "bla").sql() - 'JOIN x USING (foo, bla)' - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. 
- append: if `True`, concatenate the new expressions to the existing "using" list. - Otherwise, this resets the expression. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Join expression. - """ - join = _apply_list_builder( - *expressions, - instance=self, - arg="using", - append=append, - dialect=dialect, - copy=copy, - **opts, - ) - - if join.kind == "CROSS": - join.set("kind", None) - - return join - - -class Lateral(UDTF): - arg_types = { - "this": True, - "view": False, - "outer": False, - "alias": False, - "cross_apply": False, # True -> CROSS APPLY, False -> OUTER APPLY - "ordinality": False, - } - - -# https://docs.snowflake.com/sql-reference/literals-table -# https://docs.snowflake.com/en/sql-reference/functions-table#using-a-table-function -class TableFromRows(UDTF): - arg_types = { - "this": True, - "alias": False, - "joins": False, - "pivots": False, - "sample": False, - } - - -class MatchRecognizeMeasure(Expression): - arg_types = { - "this": True, - "window_frame": False, - } - - -class MatchRecognize(Expression): - arg_types = { - "partition_by": False, - "order": False, - "measures": False, - "rows": False, - "after": False, - "pattern": False, - "define": False, - "alias": False, - } - - -# Clickhouse FROM FINAL modifier -# https://clickhouse.com/docs/en/sql-reference/statements/select/from/#final-modifier -class Final(Expression): - pass - - -class Offset(Expression): - arg_types = {"this": False, "expression": True, "expressions": False} - - -class Order(Expression): - arg_types = {"this": False, "expressions": True, "siblings": False} - - -# https://clickhouse.com/docs/en/sql-reference/statements/select/order-by#order-by-expr-with-fill-modifier -class WithFill(Expression): - arg_types = { - "from": False, - "to": False, - "step": False, - "interpolate": False, - } - - -# hive specific sorts -# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+SortBy -class Cluster(Order): - pass - - -class Distribute(Order): - pass - - -class Sort(Order): - pass - - -class Ordered(Expression): - arg_types = {"this": True, "desc": False, "nulls_first": True, "with_fill": False} - - @property - def name(self) -> str: - return self.this.name - - -class Property(Expression): - arg_types = {"this": True, "value": True} - - -class GrantPrivilege(Expression): - arg_types = {"this": True, "expressions": False} - - -class GrantPrincipal(Expression): - arg_types = {"this": True, "kind": False} - - -class AllowedValuesProperty(Expression): - arg_types = {"expressions": True} - - -class AlgorithmProperty(Property): - arg_types = {"this": True} - - -class AutoIncrementProperty(Property): - arg_types = {"this": True} - - -# https://docs.aws.amazon.com/prescriptive-guidance/latest/materialized-views-redshift/refreshing-materialized-views.html -class AutoRefreshProperty(Property): - arg_types = {"this": True} - - -class BackupProperty(Property): - arg_types = {"this": True} - - -class BlockCompressionProperty(Property): - arg_types = { - "autotemp": False, - "always": False, - "default": False, - "manual": False, - "never": False, - } - - -class CharacterSetProperty(Property): - arg_types = {"this": True, "default": True} - - -class ChecksumProperty(Property): - arg_types = {"on": False, "default": False} - - -class CollateProperty(Property): - arg_types = {"this": True, "default": False} - - -class 
CopyGrantsProperty(Property): - arg_types = {} - - -class DataBlocksizeProperty(Property): - arg_types = { - "size": False, - "units": False, - "minimum": False, - "maximum": False, - "default": False, - } - - -class DataDeletionProperty(Property): - arg_types = {"on": True, "filter_col": False, "retention_period": False} - - -class DefinerProperty(Property): - arg_types = {"this": True} - - -class DistKeyProperty(Property): - arg_types = {"this": True} - - -# https://docs.starrocks.io/docs/sql-reference/sql-statements/data-definition/CREATE_TABLE/#distribution_desc -# https://doris.apache.org/docs/sql-manual/sql-statements/Data-Definition-Statements/Create/CREATE-TABLE?_highlight=create&_highlight=table#distribution_desc -class DistributedByProperty(Property): - arg_types = {"expressions": False, "kind": True, "buckets": False, "order": False} - - -class DistStyleProperty(Property): - arg_types = {"this": True} - - -class DuplicateKeyProperty(Property): - arg_types = {"expressions": True} - - -class EngineProperty(Property): - arg_types = {"this": True} - - -class HeapProperty(Property): - arg_types = {} - - -class ToTableProperty(Property): - arg_types = {"this": True} - - -class ExecuteAsProperty(Property): - arg_types = {"this": True} - - -class ExternalProperty(Property): - arg_types = {"this": False} - - -class FallbackProperty(Property): - arg_types = {"no": True, "protection": False} - - -class FileFormatProperty(Property): - arg_types = {"this": False, "expressions": False} - - -class CredentialsProperty(Property): - arg_types = {"expressions": True} - - -class FreespaceProperty(Property): - arg_types = {"this": True, "percent": False} - - -class GlobalProperty(Property): - arg_types = {} - - -class IcebergProperty(Property): - arg_types = {} - - -class InheritsProperty(Property): - arg_types = {"expressions": True} - - -class InputModelProperty(Property): - arg_types = {"this": True} - - -class OutputModelProperty(Property): - arg_types = {"this": True} - - -class IsolatedLoadingProperty(Property): - arg_types = {"no": False, "concurrent": False, "target": False} - - -class JournalProperty(Property): - arg_types = { - "no": False, - "dual": False, - "before": False, - "local": False, - "after": False, - } - - -class LanguageProperty(Property): - arg_types = {"this": True} - - -class EnviromentProperty(Property): - arg_types = {"expressions": True} - - -# spark ddl -class ClusteredByProperty(Property): - arg_types = {"expressions": True, "sorted_by": False, "buckets": True} - - -class DictProperty(Property): - arg_types = {"this": True, "kind": True, "settings": False} - - -class DictSubProperty(Property): - pass - - -class DictRange(Property): - arg_types = {"this": True, "min": True, "max": True} - - -class DynamicProperty(Property): - arg_types = {} - - -# Clickhouse CREATE ... 
ON CLUSTER modifier -# https://clickhouse.com/docs/en/sql-reference/distributed-ddl -class OnCluster(Property): - arg_types = {"this": True} - - -# Clickhouse EMPTY table "property" -class EmptyProperty(Property): - arg_types = {} - - -class LikeProperty(Property): - arg_types = {"this": True, "expressions": False} - - -class LocationProperty(Property): - arg_types = {"this": True} - - -class LockProperty(Property): - arg_types = {"this": True} - - -class LockingProperty(Property): - arg_types = { - "this": False, - "kind": True, - "for_or_in": False, - "lock_type": True, - "override": False, - } - - -class LogProperty(Property): - arg_types = {"no": True} - - -class MaterializedProperty(Property): - arg_types = {"this": False} - - -class MergeBlockRatioProperty(Property): - arg_types = {"this": False, "no": False, "default": False, "percent": False} - - -class NoPrimaryIndexProperty(Property): - arg_types = {} - - -class OnProperty(Property): - arg_types = {"this": True} - - -class OnCommitProperty(Property): - arg_types = {"delete": False} - - -class PartitionedByProperty(Property): - arg_types = {"this": True} - - -class PartitionedByBucket(Property): - arg_types = {"this": True, "expression": True} - - -class PartitionByTruncate(Property): - arg_types = {"this": True, "expression": True} - - -# https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/ -class PartitionByRangeProperty(Property): - arg_types = {"partition_expressions": True, "create_expressions": True} - - -# https://docs.starrocks.io/docs/table_design/data_distribution/#range-partitioning -class PartitionByRangePropertyDynamic(Expression): - arg_types = {"this": False, "start": True, "end": True, "every": True} - - -# https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/ -class UniqueKeyProperty(Property): - arg_types = {"expressions": True} - - -# https://www.postgresql.org/docs/current/sql-createtable.html -class PartitionBoundSpec(Expression): - # this -> IN / MODULUS, expression -> REMAINDER, from_expressions -> FROM (...), to_expressions -> TO (...) - arg_types = { - "this": False, - "expression": False, - "from_expressions": False, - "to_expressions": False, - } - - -class PartitionedOfProperty(Property): - # this -> parent_table (schema), expression -> FOR VALUES ... 
/ DEFAULT - arg_types = {"this": True, "expression": True} - - -class StreamingTableProperty(Property): - arg_types = {} - - -class RemoteWithConnectionModelProperty(Property): - arg_types = {"this": True} - - -class ReturnsProperty(Property): - arg_types = {"this": False, "is_table": False, "table": False, "null": False} - - -class StrictProperty(Property): - arg_types = {} - - -class RowFormatProperty(Property): - arg_types = {"this": True} - - -class RowFormatDelimitedProperty(Property): - # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml - arg_types = { - "fields": False, - "escaped": False, - "collection_items": False, - "map_keys": False, - "lines": False, - "null": False, - "serde": False, - } - - -class RowFormatSerdeProperty(Property): - arg_types = {"this": True, "serde_properties": False} - - -# https://spark.apache.org/docs/3.1.2/sql-ref-syntax-qry-select-transform.html -class QueryTransform(Expression): - arg_types = { - "expressions": True, - "command_script": True, - "schema": False, - "row_format_before": False, - "record_writer": False, - "row_format_after": False, - "record_reader": False, - } - - -class SampleProperty(Property): - arg_types = {"this": True} - - -# https://prestodb.io/docs/current/sql/create-view.html#synopsis -class SecurityProperty(Property): - arg_types = {"this": True} - - -class SchemaCommentProperty(Property): - arg_types = {"this": True} - - -class SerdeProperties(Property): - arg_types = {"expressions": True, "with": False} - - -class SetProperty(Property): - arg_types = {"multi": True} - - -class SharingProperty(Property): - arg_types = {"this": False} - - -class SetConfigProperty(Property): - arg_types = {"this": True} - - -class SettingsProperty(Property): - arg_types = {"expressions": True} - - -class SortKeyProperty(Property): - arg_types = {"this": True, "compound": False} - - -class SqlReadWriteProperty(Property): - arg_types = {"this": True} - - -class SqlSecurityProperty(Property): - arg_types = {"definer": True} - - -class StabilityProperty(Property): - arg_types = {"this": True} - - -class StorageHandlerProperty(Property): - arg_types = {"this": True} - - -class TemporaryProperty(Property): - arg_types = {"this": False} - - -class SecureProperty(Property): - arg_types = {} - - -# https://docs.snowflake.com/en/sql-reference/sql/create-table -class Tags(ColumnConstraintKind, Property): - arg_types = {"expressions": True} - - -class TransformModelProperty(Property): - arg_types = {"expressions": True} - - -class TransientProperty(Property): - arg_types = {"this": False} - - -class UnloggedProperty(Property): - arg_types = {} - - -# https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-using-template -class UsingTemplateProperty(Property): - arg_types = {"this": True} - - -# https://learn.microsoft.com/en-us/sql/t-sql/statements/create-view-transact-sql?view=sql-server-ver16 -class ViewAttributeProperty(Property): - arg_types = {"this": True} - - -class VolatileProperty(Property): - arg_types = {"this": False} - - -class WithDataProperty(Property): - arg_types = {"no": True, "statistics": False} - - -class WithJournalTableProperty(Property): - arg_types = {"this": True} - - -class WithSchemaBindingProperty(Property): - arg_types = {"this": True} - - -class WithSystemVersioningProperty(Property): - arg_types = { - "on": False, - "this": False, - "data_consistency": False, - "retention_period": False, - "with": True, - } - - -class WithProcedureOptions(Property): - arg_types = {"expressions": True} - - 
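Each of these `Property` subclasses corresponds to a clause the parser can attach to a `Create` statement's `properties` argument. A hedged sketch of what that looks like in practice, assuming the top-level `parse_one` helper and the MySQL dialect (illustrative only, not part of this file):

```python
from sqlglot import parse_one

create = parse_one("CREATE TABLE t (x INT) ENGINE=InnoDB", read="mysql")
properties = create.args["properties"]  # a Properties node holding typed children
print([type(p).__name__ for p in properties.expressions])
# ['EngineProperty']
```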
-class EncodeProperty(Property): - arg_types = {"this": True, "properties": False, "key": False} - - -class IncludeProperty(Property): - arg_types = {"this": True, "alias": False, "column_def": False} - - -class ForceProperty(Property): - arg_types = {} - - -class Properties(Expression): - arg_types = {"expressions": True} - - NAME_TO_PROPERTY = { - "ALGORITHM": AlgorithmProperty, - "AUTO_INCREMENT": AutoIncrementProperty, - "CHARACTER SET": CharacterSetProperty, - "CLUSTERED_BY": ClusteredByProperty, - "COLLATE": CollateProperty, - "COMMENT": SchemaCommentProperty, - "CREDENTIALS": CredentialsProperty, - "DEFINER": DefinerProperty, - "DISTKEY": DistKeyProperty, - "DISTRIBUTED_BY": DistributedByProperty, - "DISTSTYLE": DistStyleProperty, - "ENGINE": EngineProperty, - "EXECUTE AS": ExecuteAsProperty, - "FORMAT": FileFormatProperty, - "LANGUAGE": LanguageProperty, - "LOCATION": LocationProperty, - "LOCK": LockProperty, - "PARTITIONED_BY": PartitionedByProperty, - "RETURNS": ReturnsProperty, - "ROW_FORMAT": RowFormatProperty, - "SORTKEY": SortKeyProperty, - "ENCODE": EncodeProperty, - "INCLUDE": IncludeProperty, - } - - PROPERTY_TO_NAME = {v: k for k, v in NAME_TO_PROPERTY.items()} - - # CREATE property locations - # Form: schema specified - # create [POST_CREATE] - # table a [POST_NAME] - # (b int) [POST_SCHEMA] - # with ([POST_WITH]) - # index (b) [POST_INDEX] - # - # Form: alias selection - # create [POST_CREATE] - # table a [POST_NAME] - # as [POST_ALIAS] (select * from b) [POST_EXPRESSION] - # index (c) [POST_INDEX] - class Location(AutoName): - POST_CREATE = auto() - POST_NAME = auto() - POST_SCHEMA = auto() - POST_WITH = auto() - POST_ALIAS = auto() - POST_EXPRESSION = auto() - POST_INDEX = auto() - UNSUPPORTED = auto() - - @classmethod - def from_dict(cls, properties_dict: t.Dict) -> Properties: - expressions = [] - for key, value in properties_dict.items(): - property_cls = cls.NAME_TO_PROPERTY.get(key.upper()) - if property_cls: - expressions.append(property_cls(this=convert(value))) - else: - expressions.append(Property(this=Literal.string(key), value=convert(value))) - - return cls(expressions=expressions) - - -class Qualify(Expression): - pass - - -class InputOutputFormat(Expression): - arg_types = {"input_format": False, "output_format": False} - - -# https://www.ibm.com/docs/en/ias?topic=procedures-return-statement-in-sql -class Return(Expression): - pass - - -class Reference(Expression): - arg_types = {"this": True, "expressions": False, "options": False} - - -class Tuple(Expression): - arg_types = {"expressions": False} - - def isin( - self, - *expressions: t.Any, - query: t.Optional[ExpOrStr] = None, - unnest: t.Optional[ExpOrStr] | t.Collection[ExpOrStr] = None, - copy: bool = True, - **opts, - ) -> In: - return In( - this=maybe_copy(self, copy), - expressions=[convert(e, copy=copy) for e in expressions], - query=maybe_parse(query, copy=copy, **opts) if query else None, - unnest=( - Unnest( - expressions=[ - maybe_parse(t.cast(ExpOrStr, e), copy=copy, **opts) - for e in ensure_list(unnest) - ] - ) - if unnest - else None - ), - ) - - -QUERY_MODIFIERS = { - "match": False, - "laterals": False, - "joins": False, - "connect": False, - "pivots": False, - "prewhere": False, - "where": False, - "group": False, - "having": False, - "qualify": False, - "windows": False, - "distribute": False, - "sort": False, - "cluster": False, - "order": False, - "limit": False, - "offset": False, - "locks": False, - "sample": False, - "settings": False, - "format": False, - "options": False, -} 
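`Properties.from_dict` resolves each key against `NAME_TO_PROPERTY` (both defined a few lines above) and falls back to a generic `Property` for unknown names. A small illustrative example, assuming the public `sqlglot.exp` module:

```python
from sqlglot import exp

props = exp.Properties.from_dict({"FORMAT": "parquet", "custom_key": 1})
print([type(p).__name__ for p in props.expressions])
# ['FileFormatProperty', 'Property']  (known keys map to typed classes,
#                                      unknown keys fall back to Property)
```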
- - -# https://learn.microsoft.com/en-us/sql/t-sql/queries/option-clause-transact-sql?view=sql-server-ver16 -# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-query?view=sql-server-ver16 -class QueryOption(Expression): - arg_types = {"this": True, "expression": False} - - -# https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16 -class WithTableHint(Expression): - arg_types = {"expressions": True} - - -# https://dev.mysql.com/doc/refman/8.0/en/index-hints.html -class IndexTableHint(Expression): - arg_types = {"this": True, "expressions": False, "target": False} - - -# https://docs.snowflake.com/en/sql-reference/constructs/at-before -class HistoricalData(Expression): - arg_types = {"this": True, "kind": True, "expression": True} - - -# https://docs.snowflake.com/en/sql-reference/sql/put -class Put(Expression): - arg_types = {"this": True, "target": True, "properties": False} - - -# https://docs.snowflake.com/en/sql-reference/sql/get -class Get(Expression): - arg_types = {"this": True, "target": True, "properties": False} - - -class Table(Expression): - arg_types = { - "this": False, - "alias": False, - "db": False, - "catalog": False, - "laterals": False, - "joins": False, - "pivots": False, - "hints": False, - "system_time": False, - "version": False, - "format": False, - "pattern": False, - "ordinality": False, - "when": False, - "only": False, - "partition": False, - "changes": False, - "rows_from": False, - "sample": False, - } - - @property - def name(self) -> str: - if not self.this or isinstance(self.this, Func): - return "" - return self.this.name - - @property - def db(self) -> str: - return self.text("db") - - @property - def catalog(self) -> str: - return self.text("catalog") - - @property - def selects(self) -> t.List[Expression]: - return [] - - @property - def named_selects(self) -> t.List[str]: - return [] - - @property - def parts(self) -> t.List[Expression]: - """Return the parts of a table in order catalog, db, table.""" - parts: t.List[Expression] = [] - - for arg in ("catalog", "db", "this"): - part = self.args.get(arg) - - if isinstance(part, Dot): - parts.extend(part.flatten()) - elif isinstance(part, Expression): - parts.append(part) - - return parts - - def to_column(self, copy: bool = True) -> Expression: - parts = self.parts - last_part = parts[-1] - - if isinstance(last_part, Identifier): - col: Expression = column(*reversed(parts[0:4]), fields=parts[4:], copy=copy) # type: ignore - else: - # This branch will be reached if a function or array is wrapped in a `Table` - col = last_part - - alias = self.args.get("alias") - if alias: - col = alias_(col, alias.this, copy=copy) - - return col - - -class SetOperation(Query): - arg_types = { - "with": False, - "this": True, - "expression": True, - "distinct": False, - "by_name": False, - "side": False, - "kind": False, - "on": False, - **QUERY_MODIFIERS, - } - - def select( - self: S, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> S: - this = maybe_copy(self, copy) - this.this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts) - this.expression.unnest().select( - *expressions, append=append, dialect=dialect, copy=False, **opts - ) - return this - - @property - def named_selects(self) -> t.List[str]: - return self.this.unnest().named_selects - - @property - def is_star(self) -> bool: - return self.this.is_star or self.expression.is_star - - 
@property - def selects(self) -> t.List[Expression]: - return self.this.unnest().selects - - @property - def left(self) -> Query: - return self.this - - @property - def right(self) -> Query: - return self.expression - - @property - def kind(self) -> str: - return self.text("kind").upper() - - @property - def side(self) -> str: - return self.text("side").upper() - - -class Union(SetOperation): - pass - - -class Except(SetOperation): - pass - - -class Intersect(SetOperation): - pass - - -class Update(DML): - arg_types = { - "with": False, - "this": False, - "expressions": True, - "from": False, - "where": False, - "returning": False, - "order": False, - "limit": False, - } - - def table( - self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts - ) -> Update: - """ - Set the table to update. - - Example: - >>> Update().table("my_table").set_("x = 1").sql() - 'UPDATE my_table SET x = 1' - - Args: - expression : the SQL code strings to parse. - If a `Table` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Table`. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Update expression. - """ - return _apply_builder( - expression=expression, - instance=self, - arg="this", - into=Table, - prefix=None, - dialect=dialect, - copy=copy, - **opts, - ) - - def set_( - self, - *expressions: ExpOrStr, - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Update: - """ - Append to or set the SET expressions. - - Example: - >>> Update().table("my_table").set_("x = 1").sql() - 'UPDATE my_table SET x = 1' - - Args: - *expressions: the SQL code strings to parse. - If `Expression` instance(s) are passed, they will be used as-is. - Multiple expressions are combined with a comma. - append: if `True`, add the new expressions to any existing SET expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - """ - return _apply_list_builder( - *expressions, - instance=self, - arg="expressions", - append=append, - into=Expression, - prefix=None, - dialect=dialect, - copy=copy, - **opts, - ) - - def where( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Append to or set the WHERE expressions. - - Example: - >>> Update().table("tbl").set_("x = 1").where("x = 'a' OR x < 'b'").sql() - "UPDATE tbl SET x = 1 WHERE x = 'a' OR x < 'b'" - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - Multiple expressions are combined with an AND operator. - append: if `True`, AND the new expressions to any existing expression. - Otherwise, this resets the expression. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Select: the modified expression. 
- """ - return _apply_conjunction_builder( - *expressions, - instance=self, - arg="where", - append=append, - into=Where, - dialect=dialect, - copy=copy, - **opts, - ) - - def from_( - self, - expression: t.Optional[ExpOrStr] = None, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Update: - """ - Set the FROM expression. - - Example: - >>> Update().table("my_table").set_("x = 1").from_("baz").sql() - 'UPDATE my_table SET x = 1 FROM baz' - - Args: - expression : the SQL code strings to parse. - If a `From` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `From`. - If nothing is passed in then a from is not applied to the expression - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Update expression. - """ - if not expression: - return maybe_copy(self, copy) - - return _apply_builder( - expression=expression, - instance=self, - arg="from", - into=From, - prefix="FROM", - dialect=dialect, - copy=copy, - **opts, - ) - - def with_( - self, - alias: ExpOrStr, - as_: ExpOrStr, - recursive: t.Optional[bool] = None, - materialized: t.Optional[bool] = None, - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Update: - """ - Append to or set the common table expressions. - - Example: - >>> Update().table("my_table").set_("x = 1").from_("baz").with_("baz", "SELECT id FROM foo").sql() - 'WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz' - - Args: - alias: the SQL code string to parse as the table name. - If an `Expression` instance is passed, this is used as-is. - as_: the SQL code string to parse as the table expression. - If an `Expression` instance is passed, it will be used as-is. - recursive: set the RECURSIVE part of the expression. Defaults to `False`. - materialized: set the MATERIALIZED part of the expression. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified expression. 
- """ - return _apply_cte_builder( - self, - alias, - as_, - recursive=recursive, - materialized=materialized, - append=append, - dialect=dialect, - copy=copy, - **opts, - ) - - -class Values(UDTF): - arg_types = {"expressions": True, "alias": False} - - -class Var(Expression): - pass - - -class Version(Expression): - """ - Time travel, iceberg, bigquery etc - https://trino.io/docs/current/connector/iceberg.html?highlight=snapshot#using-snapshots - https://www.databricks.com/blog/2019/02/04/introducing-delta-time-travel-for-large-scale-data-lakes.html - https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#for_system_time_as_of - https://learn.microsoft.com/en-us/sql/relational-databases/tables/querying-data-in-a-system-versioned-temporal-table?view=sql-server-ver16 - this is either TIMESTAMP or VERSION - kind is ("AS OF", "BETWEEN") - """ - - arg_types = {"this": True, "kind": True, "expression": False} - - -class Schema(Expression): - arg_types = {"this": False, "expressions": False} - - -# https://dev.mysql.com/doc/refman/8.0/en/select.html -# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/SELECT.html -class Lock(Expression): - arg_types = {"update": True, "expressions": False, "wait": False} - - -class Select(Query): - arg_types = { - "with": False, - "kind": False, - "expressions": False, - "hint": False, - "distinct": False, - "into": False, - "from": False, - "operation_modifiers": False, - **QUERY_MODIFIERS, - } - - def from_( - self, expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts - ) -> Select: - """ - Set the FROM expression. - - Example: - >>> Select().from_("tbl").select("x").sql() - 'SELECT x FROM tbl' - - Args: - expression : the SQL code strings to parse. - If a `From` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `From`. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_builder( - expression=expression, - instance=self, - arg="from", - into=From, - prefix="FROM", - dialect=dialect, - copy=copy, - **opts, - ) - - def group_by( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Set the GROUP BY expression. - - Example: - >>> Select().from_("tbl").select("x", "COUNT(1)").group_by("x").sql() - 'SELECT x, COUNT(1) FROM tbl GROUP BY x' - - Args: - *expressions: the SQL code strings to parse. - If a `Group` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Group`. - If nothing is passed in then a group by is not applied to the expression - append: if `True`, add to any existing expressions. - Otherwise, this flattens all the `Group` expression into a single expression. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. 
- """ - if not expressions: - return self if not copy else self.copy() - - return _apply_child_list_builder( - *expressions, - instance=self, - arg="group", - append=append, - copy=copy, - prefix="GROUP BY", - into=Group, - dialect=dialect, - **opts, - ) - - def sort_by( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Set the SORT BY expression. - - Example: - >>> Select().from_("tbl").select("x").sort_by("x DESC").sql(dialect="hive") - 'SELECT x FROM tbl SORT BY x DESC' - - Args: - *expressions: the SQL code strings to parse. - If a `Group` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `SORT`. - append: if `True`, add to any existing expressions. - Otherwise, this flattens all the `Order` expression into a single expression. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_child_list_builder( - *expressions, - instance=self, - arg="sort", - append=append, - copy=copy, - prefix="SORT BY", - into=Sort, - dialect=dialect, - **opts, - ) - - def cluster_by( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Set the CLUSTER BY expression. - - Example: - >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql(dialect="hive") - 'SELECT x FROM tbl CLUSTER BY x DESC' - - Args: - *expressions: the SQL code strings to parse. - If a `Group` instance is passed, this is used as-is. - If another `Expression` instance is passed, it will be wrapped in a `Cluster`. - append: if `True`, add to any existing expressions. - Otherwise, this flattens all the `Order` expression into a single expression. - dialect: the dialect used to parse the input expression. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_child_list_builder( - *expressions, - instance=self, - arg="cluster", - append=append, - copy=copy, - prefix="CLUSTER BY", - into=Cluster, - dialect=dialect, - **opts, - ) - - def select( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - return _apply_list_builder( - *expressions, - instance=self, - arg="expressions", - append=append, - dialect=dialect, - into=Expression, - copy=copy, - **opts, - ) - - def lateral( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Append to or set the LATERAL expressions. - - Example: - >>> Select().select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl").sql() - 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z' - - Args: - *expressions: the SQL code strings to parse. - If an `Expression` instance is passed, it will be used as-is. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. 
- - Returns: - The modified Select expression. - """ - return _apply_list_builder( - *expressions, - instance=self, - arg="laterals", - append=append, - into=Lateral, - prefix="LATERAL VIEW", - dialect=dialect, - copy=copy, - **opts, - ) - - def join( - self, - expression: ExpOrStr, - on: t.Optional[ExpOrStr] = None, - using: t.Optional[ExpOrStr | t.Collection[ExpOrStr]] = None, - append: bool = True, - join_type: t.Optional[str] = None, - join_alias: t.Optional[Identifier | str] = None, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Append to or set the JOIN expressions. - - Example: - >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y").sql() - 'SELECT * FROM tbl JOIN tbl2 ON tbl1.y = tbl2.y' - - >>> Select().select("1").from_("a").join("b", using=["x", "y", "z"]).sql() - 'SELECT 1 FROM a JOIN b USING (x, y, z)' - - Use `join_type` to change the type of join: - - >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y", join_type="left outer").sql() - 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y' - - Args: - expression: the SQL code string to parse. - If an `Expression` instance is passed, it will be used as-is. - on: optionally specify the join "on" criteria as a SQL string. - If an `Expression` instance is passed, it will be used as-is. - using: optionally specify the join "using" criteria as a SQL string. - If an `Expression` instance is passed, it will be used as-is. - append: if `True`, add to any existing expressions. - Otherwise, this resets the expressions. - join_type: if set, alter the parsed join type. - join_alias: an optional alias for the joined source. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - Select: the modified expression. - """ - parse_args: t.Dict[str, t.Any] = {"dialect": dialect, **opts} - - try: - expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) - except ParseError: - expression = maybe_parse(expression, into=(Join, Expression), **parse_args) - - join = expression if isinstance(expression, Join) else Join(this=expression) - - if isinstance(join.this, Select): - join.this.replace(join.this.subquery()) - - if join_type: - method: t.Optional[Token] - side: t.Optional[Token] - kind: t.Optional[Token] - - method, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore - - if method: - join.set("method", method.text) - if side: - join.set("side", side.text) - if kind: - join.set("kind", kind.text) - - if on: - on = and_(*ensure_list(on), dialect=dialect, copy=copy, **opts) - join.set("on", on) - - if using: - join = _apply_list_builder( - *ensure_list(using), - instance=join, - arg="using", - append=append, - copy=copy, - into=Identifier, - **opts, - ) - - if join_alias: - join.set("this", alias_(join.this, join_alias, table=True)) - - return _apply_list_builder( - join, - instance=self, - arg="joins", - append=append, - copy=copy, - **opts, - ) - - def having( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - """ - Append to or set the HAVING expressions. - - Example: - >>> Select().select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 3").sql() - 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3' - - Args: - *expressions: the SQL code strings to parse. 
- If an `Expression` instance is passed, it will be used as-is. - Multiple expressions are combined with an AND operator. - append: if `True`, AND the new expressions to any existing expression. - Otherwise, this resets the expression. - dialect: the dialect used to parse the input expressions. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input expressions. - - Returns: - The modified Select expression. - """ - return _apply_conjunction_builder( - *expressions, - instance=self, - arg="having", - append=append, - into=Having, - dialect=dialect, - copy=copy, - **opts, - ) - - def window( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - return _apply_list_builder( - *expressions, - instance=self, - arg="windows", - append=append, - into=Window, - dialect=dialect, - copy=copy, - **opts, - ) - - def qualify( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Select: - return _apply_conjunction_builder( - *expressions, - instance=self, - arg="qualify", - append=append, - into=Qualify, - dialect=dialect, - copy=copy, - **opts, - ) - - def distinct( - self, *ons: t.Optional[ExpOrStr], distinct: bool = True, copy: bool = True - ) -> Select: - """ - Set the OFFSET expression. - - Example: - >>> Select().from_("tbl").select("x").distinct().sql() - 'SELECT DISTINCT x FROM tbl' - - Args: - ons: the expressions to distinct on - distinct: whether the Select should be distinct - copy: if `False`, modify this expression instance in-place. - - Returns: - Select: the modified expression. - """ - instance = maybe_copy(self, copy) - on = Tuple(expressions=[maybe_parse(on, copy=copy) for on in ons if on]) if ons else None - instance.set("distinct", Distinct(on=on) if distinct else None) - return instance - - def ctas( - self, - table: ExpOrStr, - properties: t.Optional[t.Dict] = None, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Create: - """ - Convert this expression to a CREATE TABLE AS statement. - - Example: - >>> Select().select("*").from_("tbl").ctas("x").sql() - 'CREATE TABLE x AS SELECT * FROM tbl' - - Args: - table: the SQL code string to parse as the table name. - If another `Expression` instance is passed, it will be used as-is. - properties: an optional mapping of table properties - dialect: the dialect used to parse the input table. - copy: if `False`, modify this expression instance in-place. - opts: other options to use to parse the input table. - - Returns: - The new Create expression. - """ - instance = maybe_copy(self, copy) - table_expression = maybe_parse(table, into=Table, dialect=dialect, **opts) - - properties_expression = None - if properties: - properties_expression = Properties.from_dict(properties) - - return Create( - this=table_expression, - kind="TABLE", - expression=instance, - properties=properties_expression, - ) - - def lock(self, update: bool = True, copy: bool = True) -> Select: - """ - Set the locking read mode for this expression. - - Examples: - >>> Select().select("x").from_("tbl").where("x = 'a'").lock().sql("mysql") - "SELECT x FROM tbl WHERE x = 'a' FOR UPDATE" - - >>> Select().select("x").from_("tbl").where("x = 'a'").lock(update=False).sql("mysql") - "SELECT x FROM tbl WHERE x = 'a' FOR SHARE" - - Args: - update: if `True`, the locking type will be `FOR UPDATE`, else it will be `FOR SHARE`. 
- copy: if `False`, modify this expression instance in-place. - - Returns: - The modified expression. - """ - inst = maybe_copy(self, copy) - inst.set("locks", [Lock(update=update)]) - - return inst - - def hint(self, *hints: ExpOrStr, dialect: DialectType = None, copy: bool = True) -> Select: - """ - Set hints for this expression. - - Examples: - >>> Select().select("x").from_("tbl").hint("BROADCAST(y)").sql(dialect="spark") - 'SELECT /*+ BROADCAST(y) */ x FROM tbl' - - Args: - hints: The SQL code strings to parse as the hints. - If an `Expression` instance is passed, it will be used as-is. - dialect: The dialect used to parse the hints. - copy: If `False`, modify this expression instance in-place. - - Returns: - The modified expression. - """ - inst = maybe_copy(self, copy) - inst.set( - "hint", Hint(expressions=[maybe_parse(h, copy=copy, dialect=dialect) for h in hints]) - ) - - return inst - - @property - def named_selects(self) -> t.List[str]: - return [e.output_name for e in self.expressions if e.alias_or_name] - - @property - def is_star(self) -> bool: - return any(expression.is_star for expression in self.expressions) - - @property - def selects(self) -> t.List[Expression]: - return self.expressions - - -UNWRAPPED_QUERIES = (Select, SetOperation) - - -class Subquery(DerivedTable, Query): - arg_types = { - "this": True, - "alias": False, - "with": False, - **QUERY_MODIFIERS, - } - - def unnest(self): - """Returns the first non subquery.""" - expression = self - while isinstance(expression, Subquery): - expression = expression.this - return expression - - def unwrap(self) -> Subquery: - expression = self - while expression.same_parent and expression.is_wrapper: - expression = t.cast(Subquery, expression.parent) - return expression - - def select( - self, - *expressions: t.Optional[ExpOrStr], - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, - ) -> Subquery: - this = maybe_copy(self, copy) - this.unnest().select(*expressions, append=append, dialect=dialect, copy=False, **opts) - return this - - @property - def is_wrapper(self) -> bool: - """ - Whether this Subquery acts as a simple wrapper around another expression. 
- - SELECT * FROM (((SELECT * FROM t))) - ^ - This corresponds to a "wrapper" Subquery node - """ - return all(v is None for k, v in self.args.items() if k != "this") - - @property - def is_star(self) -> bool: - return self.this.is_star - - @property - def output_name(self) -> str: - return self.alias - - -class TableSample(Expression): - arg_types = { - "expressions": False, - "method": False, - "bucket_numerator": False, - "bucket_denominator": False, - "bucket_field": False, - "percent": False, - "rows": False, - "size": False, - "seed": False, - } - - -class Tag(Expression): - """Tags are used for generating arbitrary sql like SELECT x.""" - - arg_types = { - "this": False, - "prefix": False, - "postfix": False, - } - - -# Represents both the standard SQL PIVOT operator and DuckDB's "simplified" PIVOT syntax -# https://duckdb.org/docs/sql/statements/pivot -class Pivot(Expression): - arg_types = { - "this": False, - "alias": False, - "expressions": False, - "fields": False, - "unpivot": False, - "using": False, - "group": False, - "columns": False, - "include_nulls": False, - "default_on_null": False, - "into": False, - } - - @property - def unpivot(self) -> bool: - return bool(self.args.get("unpivot")) - - @property - def fields(self) -> t.List[Expression]: - return self.args.get("fields", []) - - -# https://duckdb.org/docs/sql/statements/unpivot#simplified-unpivot-syntax -# UNPIVOT ... INTO [NAME VALUE ][...,] -class UnpivotColumns(Expression): - arg_types = {"this": True, "expressions": True} - - -class Window(Condition): - arg_types = { - "this": True, - "partition_by": False, - "order": False, - "spec": False, - "alias": False, - "over": False, - "first": False, - } - - -class WindowSpec(Expression): - arg_types = { - "kind": False, - "start": False, - "start_side": False, - "end": False, - "end_side": False, - "exclude": False, - } - - -class PreWhere(Expression): - pass - - -class Where(Expression): - pass - - -class Star(Expression): - arg_types = {"except": False, "replace": False, "rename": False} - - @property - def name(self) -> str: - return "*" - - @property - def output_name(self) -> str: - return self.name - - -class Parameter(Condition): - arg_types = {"this": True, "expression": False} - - -class SessionParameter(Condition): - arg_types = {"this": True, "kind": False} - - -class Placeholder(Condition): - arg_types = {"this": False, "kind": False} - - @property - def name(self) -> str: - return self.this or "?" - - -class Null(Condition): - arg_types: t.Dict[str, t.Any] = {} - - @property - def name(self) -> str: - return "NULL" - - def to_py(self) -> Lit[None]: - return None - - -class Boolean(Condition): - def to_py(self) -> bool: - return self.this - - -class DataTypeParam(Expression): - arg_types = {"this": True, "expression": False} - - @property - def name(self) -> str: - return self.this.name - - -# The `nullable` arg is helpful when transpiling types from other dialects to ClickHouse, which -# assumes non-nullable types by default. Values `None` and `True` mean the type is nullable. 
-class DataType(Expression): - arg_types = { - "this": True, - "expressions": False, - "nested": False, - "values": False, - "prefix": False, - "kind": False, - "nullable": False, - } - - class Type(AutoName): - ARRAY = auto() - AGGREGATEFUNCTION = auto() - SIMPLEAGGREGATEFUNCTION = auto() - BIGDECIMAL = auto() - BIGINT = auto() - BIGSERIAL = auto() - BINARY = auto() - BIT = auto() - BLOB = auto() - BOOLEAN = auto() - BPCHAR = auto() - CHAR = auto() - DATE = auto() - DATE32 = auto() - DATEMULTIRANGE = auto() - DATERANGE = auto() - DATETIME = auto() - DATETIME2 = auto() - DATETIME64 = auto() - DECIMAL = auto() - DECIMAL32 = auto() - DECIMAL64 = auto() - DECIMAL128 = auto() - DECIMAL256 = auto() - DOUBLE = auto() - DYNAMIC = auto() - ENUM = auto() - ENUM8 = auto() - ENUM16 = auto() - FIXEDSTRING = auto() - FLOAT = auto() - GEOGRAPHY = auto() - GEOMETRY = auto() - POINT = auto() - RING = auto() - LINESTRING = auto() - MULTILINESTRING = auto() - POLYGON = auto() - MULTIPOLYGON = auto() - HLLSKETCH = auto() - HSTORE = auto() - IMAGE = auto() - INET = auto() - INT = auto() - INT128 = auto() - INT256 = auto() - INT4MULTIRANGE = auto() - INT4RANGE = auto() - INT8MULTIRANGE = auto() - INT8RANGE = auto() - INTERVAL = auto() - IPADDRESS = auto() - IPPREFIX = auto() - IPV4 = auto() - IPV6 = auto() - JSON = auto() - JSONB = auto() - LIST = auto() - LONGBLOB = auto() - LONGTEXT = auto() - LOWCARDINALITY = auto() - MAP = auto() - MEDIUMBLOB = auto() - MEDIUMINT = auto() - MEDIUMTEXT = auto() - MONEY = auto() - NAME = auto() - NCHAR = auto() - NESTED = auto() - NOTHING = auto() - NULL = auto() - NUMMULTIRANGE = auto() - NUMRANGE = auto() - NVARCHAR = auto() - OBJECT = auto() - RANGE = auto() - ROWVERSION = auto() - SERIAL = auto() - SET = auto() - SMALLDATETIME = auto() - SMALLINT = auto() - SMALLMONEY = auto() - SMALLSERIAL = auto() - STRUCT = auto() - SUPER = auto() - TEXT = auto() - TINYBLOB = auto() - TINYTEXT = auto() - TIME = auto() - TIMETZ = auto() - TIMESTAMP = auto() - TIMESTAMPNTZ = auto() - TIMESTAMPLTZ = auto() - TIMESTAMPTZ = auto() - TIMESTAMP_S = auto() - TIMESTAMP_MS = auto() - TIMESTAMP_NS = auto() - TINYINT = auto() - TSMULTIRANGE = auto() - TSRANGE = auto() - TSTZMULTIRANGE = auto() - TSTZRANGE = auto() - UBIGINT = auto() - UINT = auto() - UINT128 = auto() - UINT256 = auto() - UMEDIUMINT = auto() - UDECIMAL = auto() - UDOUBLE = auto() - UNION = auto() - UNKNOWN = auto() # Sentinel value, useful for type annotation - USERDEFINED = "USER-DEFINED" - USMALLINT = auto() - UTINYINT = auto() - UUID = auto() - VARBINARY = auto() - VARCHAR = auto() - VARIANT = auto() - VECTOR = auto() - XML = auto() - YEAR = auto() - TDIGEST = auto() - - STRUCT_TYPES = { - Type.NESTED, - Type.OBJECT, - Type.STRUCT, - Type.UNION, - } - - ARRAY_TYPES = { - Type.ARRAY, - Type.LIST, - } - - NESTED_TYPES = { - *STRUCT_TYPES, - *ARRAY_TYPES, - Type.MAP, - } - - TEXT_TYPES = { - Type.CHAR, - Type.NCHAR, - Type.NVARCHAR, - Type.TEXT, - Type.VARCHAR, - Type.NAME, - } - - SIGNED_INTEGER_TYPES = { - Type.BIGINT, - Type.INT, - Type.INT128, - Type.INT256, - Type.MEDIUMINT, - Type.SMALLINT, - Type.TINYINT, - } - - UNSIGNED_INTEGER_TYPES = { - Type.UBIGINT, - Type.UINT, - Type.UINT128, - Type.UINT256, - Type.UMEDIUMINT, - Type.USMALLINT, - Type.UTINYINT, - } - - INTEGER_TYPES = { - *SIGNED_INTEGER_TYPES, - *UNSIGNED_INTEGER_TYPES, - Type.BIT, - } - - FLOAT_TYPES = { - Type.DOUBLE, - Type.FLOAT, - } - - REAL_TYPES = { - *FLOAT_TYPES, - Type.BIGDECIMAL, - Type.DECIMAL, - Type.DECIMAL32, - Type.DECIMAL64, - Type.DECIMAL128, - 
Type.DECIMAL256, - Type.MONEY, - Type.SMALLMONEY, - Type.UDECIMAL, - Type.UDOUBLE, - } - - NUMERIC_TYPES = { - *INTEGER_TYPES, - *REAL_TYPES, - } - - TEMPORAL_TYPES = { - Type.DATE, - Type.DATE32, - Type.DATETIME, - Type.DATETIME2, - Type.DATETIME64, - Type.SMALLDATETIME, - Type.TIME, - Type.TIMESTAMP, - Type.TIMESTAMPNTZ, - Type.TIMESTAMPLTZ, - Type.TIMESTAMPTZ, - Type.TIMESTAMP_MS, - Type.TIMESTAMP_NS, - Type.TIMESTAMP_S, - Type.TIMETZ, - } - - @classmethod - def build( - cls, - dtype: DATA_TYPE, - dialect: DialectType = None, - udt: bool = False, - copy: bool = True, - **kwargs, - ) -> DataType: - """ - Constructs a DataType object. - - Args: - dtype: the data type of interest. - dialect: the dialect to use for parsing `dtype`, in case it's a string. - udt: when set to True, `dtype` will be used as-is if it can't be parsed into a - DataType, thus creating a user-defined type. - copy: whether to copy the data type. - kwargs: additional arguments to pass in the constructor of DataType. - - Returns: - The constructed DataType object. - """ - from sqlglot import parse_one - - if isinstance(dtype, str): - if dtype.upper() == "UNKNOWN": - return DataType(this=DataType.Type.UNKNOWN, **kwargs) - - try: - data_type_exp = parse_one( - dtype, read=dialect, into=DataType, error_level=ErrorLevel.IGNORE - ) - except ParseError: - if udt: - return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs) - raise - elif isinstance(dtype, (Identifier, Dot)) and udt: - return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs) - elif isinstance(dtype, DataType.Type): - data_type_exp = DataType(this=dtype) - elif isinstance(dtype, DataType): - return maybe_copy(dtype, copy) - else: - raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type") - - return DataType(**{**data_type_exp.args, **kwargs}) - - def is_type(self, *dtypes: DATA_TYPE, check_nullable: bool = False) -> bool: - """ - Checks whether this DataType matches one of the provided data types. Nested types or precision - will be compared using "structural equivalence" semantics, so e.g. array != array. - - Args: - dtypes: the data types to compare this DataType to. - check_nullable: whether to take the NULLABLE type constructor into account for the comparison. - If false, it means that NULLABLE is equivalent to INT. - - Returns: - True, if and only if there is a type in `dtypes` which is equal to this DataType. - """ - self_is_nullable = self.args.get("nullable") - for dtype in dtypes: - other_type = DataType.build(dtype, copy=False, udt=True) - other_is_nullable = other_type.args.get("nullable") - if ( - other_type.expressions - or (check_nullable and (self_is_nullable or other_is_nullable)) - or self.this == DataType.Type.USERDEFINED - or other_type.this == DataType.Type.USERDEFINED - ): - matches = self == other_type - else: - matches = self.this == other_type.this - - if matches: - return True - return False - - -# https://www.postgresql.org/docs/15/datatype-pseudo.html -class PseudoType(DataType): - arg_types = {"this": True} - - -# https://www.postgresql.org/docs/15/datatype-oid.html -class ObjectIdentifier(DataType): - arg_types = {"this": True} - - -# WHERE x EXISTS|ALL|ANY|SOME(SELECT ...) -class SubqueryPredicate(Predicate): - pass - - -class All(SubqueryPredicate): - pass - - -class Any(SubqueryPredicate): - pass - - -# Commands to interact with the databases or engines. For most of the command -# expressions we parse whatever comes after the command's name as a string. 
-class Command(Expression): - arg_types = {"this": True, "expression": False} - - -class Transaction(Expression): - arg_types = {"this": False, "modes": False, "mark": False} - - -class Commit(Expression): - arg_types = {"chain": False, "this": False, "durability": False} - - -class Rollback(Expression): - arg_types = {"savepoint": False, "this": False} - - -class Alter(Expression): - arg_types = { - "this": True, - "kind": True, - "actions": True, - "exists": False, - "only": False, - "options": False, - "cluster": False, - "not_valid": False, - } - - @property - def kind(self) -> t.Optional[str]: - kind = self.args.get("kind") - return kind and kind.upper() - - @property - def actions(self) -> t.List[Expression]: - return self.args.get("actions") or [] - - -class Analyze(Expression): - arg_types = { - "kind": False, - "this": False, - "options": False, - "mode": False, - "partition": False, - "expression": False, - "properties": False, - } - - -class AnalyzeStatistics(Expression): - arg_types = { - "kind": True, - "option": False, - "this": False, - "expressions": False, - } - - -class AnalyzeHistogram(Expression): - arg_types = { - "this": True, - "expressions": True, - "expression": False, - "update_options": False, - } - - -class AnalyzeSample(Expression): - arg_types = {"kind": True, "sample": True} - - -class AnalyzeListChainedRows(Expression): - arg_types = {"expression": False} - - -class AnalyzeDelete(Expression): - arg_types = {"kind": False} - - -class AnalyzeWith(Expression): - arg_types = {"expressions": True} - - -class AnalyzeValidate(Expression): - arg_types = { - "kind": True, - "this": False, - "expression": False, - } - - -class AnalyzeColumns(Expression): - pass - - -class UsingData(Expression): - pass - - -class AddConstraint(Expression): - arg_types = {"expressions": True} - - -class AddPartition(Expression): - arg_types = {"this": True, "exists": False} - - -class AttachOption(Expression): - arg_types = {"this": True, "expression": False} - - -class DropPartition(Expression): - arg_types = {"expressions": True, "exists": False} - - -# https://clickhouse.com/docs/en/sql-reference/statements/alter/partition#replace-partition -class ReplacePartition(Expression): - arg_types = {"expression": True, "source": True} - - -# Binary expressions like (ADD a b) -class Binary(Condition): - arg_types = {"this": True, "expression": True} - - @property - def left(self) -> Expression: - return self.this - - @property - def right(self) -> Expression: - return self.expression - - -class Add(Binary): - pass - - -class Connector(Binary): - pass - - -class BitwiseAnd(Binary): - pass - - -class BitwiseLeftShift(Binary): - pass - - -class BitwiseOr(Binary): - pass - - -class BitwiseRightShift(Binary): - pass - - -class BitwiseXor(Binary): - pass - - -class Div(Binary): - arg_types = {"this": True, "expression": True, "typed": False, "safe": False} - - -class Overlaps(Binary): - pass - - -class Dot(Binary): - @property - def is_star(self) -> bool: - return self.expression.is_star - - @property - def name(self) -> str: - return self.expression.name - - @property - def output_name(self) -> str: - return self.name - - @classmethod - def build(self, expressions: t.Sequence[Expression]) -> Dot: - """Build a Dot object with a sequence of expressions.""" - if len(expressions) < 2: - raise ValueError("Dot requires >= 2 expressions.") - - return t.cast(Dot, reduce(lambda x, y: Dot(this=x, expression=y), expressions)) - - @property - def parts(self) -> t.List[Expression]: - """Return the parts of a 
table / column in order catalog, db, table.""" - this, *parts = self.flatten() - - parts.reverse() - - for arg in COLUMN_PARTS: - part = this.args.get(arg) - - if isinstance(part, Expression): - parts.append(part) - - parts.reverse() - return parts - - -DATA_TYPE = t.Union[str, Identifier, Dot, DataType, DataType.Type] - - -class DPipe(Binary): - arg_types = {"this": True, "expression": True, "safe": False} - - -class EQ(Binary, Predicate): - pass - - -class NullSafeEQ(Binary, Predicate): - pass - - -class NullSafeNEQ(Binary, Predicate): - pass - - -# Represents e.g. := in DuckDB which is mostly used for setting parameters -class PropertyEQ(Binary): - pass - - -class Distance(Binary): - pass - - -class Escape(Binary): - pass - - -class Glob(Binary, Predicate): - pass - - -class GT(Binary, Predicate): - pass - - -class GTE(Binary, Predicate): - pass - - -class ILike(Binary, Predicate): - pass - - -class ILikeAny(Binary, Predicate): - pass - - -class IntDiv(Binary): - pass - - -class Is(Binary, Predicate): - pass - - -class Kwarg(Binary): - """Kwarg in special functions like func(kwarg => y).""" - - -class Like(Binary, Predicate): - pass - - -class LikeAny(Binary, Predicate): - pass - - -class LT(Binary, Predicate): - pass - - -class LTE(Binary, Predicate): - pass - - -class Mod(Binary): - pass - - -class Mul(Binary): - pass - - -class NEQ(Binary, Predicate): - pass - - -# https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH -class Operator(Binary): - arg_types = {"this": True, "operator": True, "expression": True} - - -class SimilarTo(Binary, Predicate): - pass - - -class Slice(Binary): - arg_types = {"this": False, "expression": False} - - -class Sub(Binary): - pass - - -# Unary Expressions -# (NOT a) -class Unary(Condition): - pass - - -class BitwiseNot(Unary): - pass - - -class Not(Unary): - pass - - -class Paren(Unary): - @property - def output_name(self) -> str: - return self.this.name - - -class Neg(Unary): - def to_py(self) -> int | Decimal: - if self.is_number: - return self.this.to_py() * -1 - return super().to_py() - - -class Alias(Expression): - arg_types = {"this": True, "alias": False} - - @property - def output_name(self) -> str: - return self.alias - - -# BigQuery requires the UNPIVOT column list aliases to be either strings or ints, but -# other dialects require identifiers. This enables us to transpile between them easily. -class PivotAlias(Alias): - pass - - -# Represents Snowflake's ANY [ ORDER BY ... 
] syntax -# https://docs.snowflake.com/en/sql-reference/constructs/pivot -class PivotAny(Expression): - arg_types = {"this": False} - - -class Aliases(Expression): - arg_types = {"this": True, "expressions": True} - - @property - def aliases(self): - return self.expressions - - -# https://docs.aws.amazon.com/redshift/latest/dg/query-super.html -class AtIndex(Expression): - arg_types = {"this": True, "expression": True} - - -class AtTimeZone(Expression): - arg_types = {"this": True, "zone": True} - - -class FromTimeZone(Expression): - arg_types = {"this": True, "zone": True} - - -class Between(Predicate): - arg_types = {"this": True, "low": True, "high": True} - - -class Bracket(Condition): - # https://cloud.google.com/bigquery/docs/reference/standard-sql/operators#array_subscript_operator - arg_types = { - "this": True, - "expressions": True, - "offset": False, - "safe": False, - "returns_list_for_maps": False, - } - - @property - def output_name(self) -> str: - if len(self.expressions) == 1: - return self.expressions[0].output_name - - return super().output_name - - -class Distinct(Expression): - arg_types = {"expressions": False, "on": False} - - -class In(Predicate): - arg_types = { - "this": True, - "expressions": False, - "query": False, - "unnest": False, - "field": False, - "is_global": False, - } - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/procedural-language#for-in -class ForIn(Expression): - arg_types = {"this": True, "expression": True} - - -class TimeUnit(Expression): - """Automatically converts unit arg into a var.""" - - arg_types = {"unit": False} - - UNABBREVIATED_UNIT_NAME = { - "D": "DAY", - "H": "HOUR", - "M": "MINUTE", - "MS": "MILLISECOND", - "NS": "NANOSECOND", - "Q": "QUARTER", - "S": "SECOND", - "US": "MICROSECOND", - "W": "WEEK", - "Y": "YEAR", - } - - VAR_LIKE = (Column, Literal, Var) - - def __init__(self, **args): - unit = args.get("unit") - if isinstance(unit, self.VAR_LIKE): - args["unit"] = Var( - this=(self.UNABBREVIATED_UNIT_NAME.get(unit.name) or unit.name).upper() - ) - elif isinstance(unit, Week): - unit.set("this", Var(this=unit.this.name.upper())) - - super().__init__(**args) - - @property - def unit(self) -> t.Optional[Var | IntervalSpan]: - return self.args.get("unit") - - -class IntervalOp(TimeUnit): - arg_types = {"unit": False, "expression": True} - - def interval(self): - return Interval( - this=self.expression.copy(), - unit=self.unit.copy() if self.unit else None, - ) - - -# https://www.oracletutorial.com/oracle-basics/oracle-interval/ -# https://trino.io/docs/current/language/types.html#interval-day-to-second -# https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html -class IntervalSpan(DataType): - arg_types = {"this": True, "expression": True} - - -class Interval(TimeUnit): - arg_types = {"this": False, "unit": False} - - -class IgnoreNulls(Expression): - pass - - -class RespectNulls(Expression): - pass - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/aggregate-function-calls#max_min_clause -class HavingMax(Expression): - arg_types = {"this": True, "expression": True, "max": True} - - -# Functions -class Func(Condition): - """ - The base class for all function expressions. - - Attributes: - is_var_len_args (bool): if set to True the last argument defined in arg_types will be - treated as a variable length argument and the argument's value will be stored as a list. 
- _sql_names (list): the SQL name (1st item in the list) and aliases (subsequent items) for this - function expression. These values are used to map this node to a name during parsing as - well as to provide the function's name during SQL string generation. By default the SQL - name is set to the expression's class name transformed to snake case. - """ - - is_var_len_args = False - - @classmethod - def from_arg_list(cls, args): - if cls.is_var_len_args: - all_arg_keys = list(cls.arg_types) - # If this function supports variable length argument treat the last argument as such. - non_var_len_arg_keys = all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys - num_non_var = len(non_var_len_arg_keys) - - args_dict = {arg_key: arg for arg, arg_key in zip(args, non_var_len_arg_keys)} - args_dict[all_arg_keys[-1]] = args[num_non_var:] - else: - args_dict = {arg_key: arg for arg, arg_key in zip(args, cls.arg_types)} - - return cls(**args_dict) - - @classmethod - def sql_names(cls): - if cls is Func: - raise NotImplementedError( - "SQL name is only supported by concrete function implementations" - ) - if "_sql_names" not in cls.__dict__: - cls._sql_names = [camel_to_snake_case(cls.__name__)] - return cls._sql_names - - @classmethod - def sql_name(cls): - return cls.sql_names()[0] - - @classmethod - def default_parser_mappings(cls): - return {name: cls.from_arg_list for name in cls.sql_names()} - - -class AggFunc(Func): - pass - - -class ArrayRemove(Func): - arg_types = {"this": True, "expression": True} - - -class ParameterizedAgg(AggFunc): - arg_types = {"this": True, "expressions": True, "params": True} - - -class Abs(Func): - pass - - -class ArgMax(AggFunc): - arg_types = {"this": True, "expression": True, "count": False} - _sql_names = ["ARG_MAX", "ARGMAX", "MAX_BY"] - - -class ArgMin(AggFunc): - arg_types = {"this": True, "expression": True, "count": False} - _sql_names = ["ARG_MIN", "ARGMIN", "MIN_BY"] - - -class ApproxTopK(AggFunc): - arg_types = {"this": True, "expression": False, "counters": False} - - -class Flatten(Func): - pass - - -# https://spark.apache.org/docs/latest/api/sql/index.html#transform -class Transform(Func): - arg_types = {"this": True, "expression": True} - - -class Anonymous(Func): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - @property - def name(self) -> str: - return self.this if isinstance(self.this, str) else self.this.name - - -class AnonymousAggFunc(AggFunc): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - -# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/combinators -class CombinedAggFunc(AnonymousAggFunc): - arg_types = {"this": True, "expressions": False} - - -class CombinedParameterizedAgg(ParameterizedAgg): - arg_types = {"this": True, "expressions": True, "params": True} - - -# https://docs.snowflake.com/en/sql-reference/functions/hll -# https://docs.aws.amazon.com/redshift/latest/dg/r_HLL_function.html -class Hll(AggFunc): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - -class ApproxDistinct(AggFunc): - arg_types = {"this": True, "accuracy": False} - _sql_names = ["APPROX_DISTINCT", "APPROX_COUNT_DISTINCT"] - - -class Apply(Func): - arg_types = {"this": True, "expression": True} - - -class Array(Func): - arg_types = {"expressions": False, "bracket_notation": False} - is_var_len_args = True - - -# https://docs.snowflake.com/en/sql-reference/functions/to_array -class ToArray(Func): - pass - - -# https://materialize.com/docs/sql/types/list/ 
-class List(Func): - arg_types = {"expressions": False} - is_var_len_args = True - - -# String pad, kind True -> LPAD, False -> RPAD -class Pad(Func): - arg_types = {"this": True, "expression": True, "fill_pattern": False, "is_left": True} - - -# https://docs.snowflake.com/en/sql-reference/functions/to_char -# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_CHAR-number.html -class ToChar(Func): - arg_types = {"this": True, "format": False, "nlsparam": False} - - -# https://docs.snowflake.com/en/sql-reference/functions/to_decimal -# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_NUMBER.html -class ToNumber(Func): - arg_types = { - "this": True, - "format": False, - "nlsparam": False, - "precision": False, - "scale": False, - } - - -# https://docs.snowflake.com/en/sql-reference/functions/to_double -class ToDouble(Func): - arg_types = { - "this": True, - "format": False, - } - - -class Columns(Func): - arg_types = {"this": True, "unpack": False} - - -# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax -class Convert(Func): - arg_types = {"this": True, "expression": True, "style": False} - - -# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/CONVERT.html -class ConvertToCharset(Func): - arg_types = {"this": True, "dest": True, "source": False} - - -class ConvertTimezone(Func): - arg_types = {"source_tz": False, "target_tz": True, "timestamp": True} - - -class GenerateSeries(Func): - arg_types = {"start": True, "end": True, "step": False, "is_end_exclusive": False} - - -# Postgres' GENERATE_SERIES function returns a row set, i.e. it implicitly explodes when it's -# used in a projection, so this expression is a helper that facilitates transpilation to other -# dialects. 
For example, we'd generate UNNEST(GENERATE_SERIES(...)) in DuckDB -class ExplodingGenerateSeries(GenerateSeries): - pass - - -class ArrayAgg(AggFunc): - arg_types = {"this": True, "nulls_excluded": False} - - -class ArrayUniqueAgg(AggFunc): - pass - - -class ArrayAll(Func): - arg_types = {"this": True, "expression": True} - - -# Represents Python's `any(f(x) for x in array)`, where `array` is `this` and `f` is `expression` -class ArrayAny(Func): - arg_types = {"this": True, "expression": True} - - -class ArrayConcat(Func): - _sql_names = ["ARRAY_CONCAT", "ARRAY_CAT"] - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - -class ArrayConcatAgg(AggFunc): - pass - - -class ArrayConstructCompact(Func): - arg_types = {"expressions": True} - is_var_len_args = True - - -class ArrayContains(Binary, Func): - _sql_names = ["ARRAY_CONTAINS", "ARRAY_HAS"] - - -class ArrayContainsAll(Binary, Func): - _sql_names = ["ARRAY_CONTAINS_ALL", "ARRAY_HAS_ALL"] - - -class ArrayFilter(Func): - arg_types = {"this": True, "expression": True} - _sql_names = ["FILTER", "ARRAY_FILTER"] - - -class ArrayToString(Func): - arg_types = {"this": True, "expression": True, "null": False} - _sql_names = ["ARRAY_TO_STRING", "ARRAY_JOIN"] - - -class ArrayIntersect(Func): - arg_types = {"expressions": True} - is_var_len_args = True - _sql_names = ["ARRAY_INTERSECT", "ARRAY_INTERSECTION"] - - -class StPoint(Func): - arg_types = {"this": True, "expression": True, "null": False} - _sql_names = ["ST_POINT", "ST_MAKEPOINT"] - - -class StDistance(Func): - arg_types = {"this": True, "expression": True, "use_spheroid": False} - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/timestamp_functions#string -class String(Func): - arg_types = {"this": True, "zone": False} - - -class StringToArray(Func): - arg_types = {"this": True, "expression": False, "null": False} - _sql_names = ["STRING_TO_ARRAY", "SPLIT_BY_STRING", "STRTOK_TO_ARRAY"] - - -class ArrayOverlaps(Binary, Func): - pass - - -class ArraySize(Func): - arg_types = {"this": True, "expression": False} - _sql_names = ["ARRAY_SIZE", "ARRAY_LENGTH"] - - -class ArraySort(Func): - arg_types = {"this": True, "expression": False} - - -class ArraySum(Func): - arg_types = {"this": True, "expression": False} - - -class ArrayUnionAgg(AggFunc): - pass - - -class Avg(AggFunc): - pass - - -class AnyValue(AggFunc): - pass - - -class Lag(AggFunc): - arg_types = {"this": True, "offset": False, "default": False} - - -class Lead(AggFunc): - arg_types = {"this": True, "offset": False, "default": False} - - -# some dialects have a distinction between first and first_value, usually first is an aggregate func -# and first_value is a window func -class First(AggFunc): - pass - - -class Last(AggFunc): - pass - - -class FirstValue(AggFunc): - pass - - -class LastValue(AggFunc): - pass - - -class NthValue(AggFunc): - arg_types = {"this": True, "offset": True} - - -class Case(Func): - arg_types = {"this": False, "ifs": True, "default": False} - - def when(self, condition: ExpOrStr, then: ExpOrStr, copy: bool = True, **opts) -> Case: - instance = maybe_copy(self, copy) - instance.append( - "ifs", - If( - this=maybe_parse(condition, copy=copy, **opts), - true=maybe_parse(then, copy=copy, **opts), - ), - ) - return instance - - def else_(self, condition: ExpOrStr, copy: bool = True, **opts) -> Case: - instance = maybe_copy(self, copy) - instance.set("default", maybe_parse(condition, copy=copy, **opts)) - return instance - - -class Cast(Func): - arg_types = { - "this": 
True, - "to": True, - "format": False, - "safe": False, - "action": False, - "default": False, - } - - @property - def name(self) -> str: - return self.this.name - - @property - def to(self) -> DataType: - return self.args["to"] - - @property - def output_name(self) -> str: - return self.name - - def is_type(self, *dtypes: DATA_TYPE) -> bool: - """ - Checks whether this Cast's DataType matches one of the provided data types. Nested types - like arrays or structs will be compared using "structural equivalence" semantics, so e.g. - array != array. - - Args: - dtypes: the data types to compare this Cast's DataType to. - - Returns: - True, if and only if there is a type in `dtypes` which is equal to this Cast's DataType. - """ - return self.to.is_type(*dtypes) - - -class TryCast(Cast): - pass - - -# https://clickhouse.com/docs/sql-reference/data-types/newjson#reading-json-paths-as-sub-columns -class JSONCast(Cast): - pass - - -class Try(Func): - pass - - -class CastToStrType(Func): - arg_types = {"this": True, "to": True} - - -# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Functions-Expressions-and-Predicates/String-Operators-and-Functions/TRANSLATE/TRANSLATE-Function-Syntax -class TranslateCharacters(Expression): - arg_types = {"this": True, "expression": True, "with_error": False} - - -class Collate(Binary, Func): - pass - - -class Ceil(Func): - arg_types = {"this": True, "decimals": False, "to": False} - _sql_names = ["CEIL", "CEILING"] - - -class Coalesce(Func): - arg_types = {"this": True, "expressions": False, "is_nvl": False, "is_null": False} - is_var_len_args = True - _sql_names = ["COALESCE", "IFNULL", "NVL"] - - -class Chr(Func): - arg_types = {"expressions": True, "charset": False} - is_var_len_args = True - _sql_names = ["CHR", "CHAR"] - - -class Concat(Func): - arg_types = {"expressions": True, "safe": False, "coalesce": False} - is_var_len_args = True - - -class ConcatWs(Concat): - _sql_names = ["CONCAT_WS"] - - -class Contains(Func): - arg_types = {"this": True, "expression": True} - - -# https://docs.oracle.com/cd/B13789_01/server.101/b10759/operators004.htm#i1035022 -class ConnectByRoot(Func): - pass - - -class Count(AggFunc): - arg_types = {"this": False, "expressions": False, "big_int": False} - is_var_len_args = True - - -class CountIf(AggFunc): - _sql_names = ["COUNT_IF", "COUNTIF"] - - -# cube root -class Cbrt(Func): - pass - - -class CurrentDate(Func): - arg_types = {"this": False} - - -class CurrentDatetime(Func): - arg_types = {"this": False} - - -class CurrentTime(Func): - arg_types = {"this": False} - - -class CurrentTimestamp(Func): - arg_types = {"this": False, "sysdate": False} - - -class CurrentSchema(Func): - arg_types = {"this": False} - - -class CurrentUser(Func): - arg_types = {"this": False} - - -class DateAdd(Func, IntervalOp): - arg_types = {"this": True, "expression": True, "unit": False} - - -class DateBin(Func, IntervalOp): - arg_types = {"this": True, "expression": True, "unit": False, "zone": False} - - -class DateSub(Func, IntervalOp): - arg_types = {"this": True, "expression": True, "unit": False} - - -class DateDiff(Func, TimeUnit): - _sql_names = ["DATEDIFF", "DATE_DIFF"] - arg_types = {"this": True, "expression": True, "unit": False, "zone": False} - - -class DateTrunc(Func): - arg_types = {"unit": True, "this": True, "zone": False} - - def __init__(self, **args): - # Across most dialects it's safe to unabbreviate the unit (e.g. 
'Q' -> 'QUARTER') except Oracle - # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html - unabbreviate = args.pop("unabbreviate", True) - - unit = args.get("unit") - if isinstance(unit, TimeUnit.VAR_LIKE): - unit_name = unit.name.upper() - if unabbreviate and unit_name in TimeUnit.UNABBREVIATED_UNIT_NAME: - unit_name = TimeUnit.UNABBREVIATED_UNIT_NAME[unit_name] - - args["unit"] = Literal.string(unit_name) - elif isinstance(unit, Week): - unit.set("this", Literal.string(unit.this.name.upper())) - - super().__init__(**args) - - @property - def unit(self) -> Expression: - return self.args["unit"] - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/datetime_functions#datetime -# expression can either be time_expr or time_zone -class Datetime(Func): - arg_types = {"this": True, "expression": False} - - -class DatetimeAdd(Func, IntervalOp): - arg_types = {"this": True, "expression": True, "unit": False} - - -class DatetimeSub(Func, IntervalOp): - arg_types = {"this": True, "expression": True, "unit": False} - - -class DatetimeDiff(Func, TimeUnit): - arg_types = {"this": True, "expression": True, "unit": False} - - -class DatetimeTrunc(Func, TimeUnit): - arg_types = {"this": True, "unit": True, "zone": False} - - -class DayOfWeek(Func): - _sql_names = ["DAY_OF_WEEK", "DAYOFWEEK"] - - -# https://duckdb.org/docs/sql/functions/datepart.html#part-specifiers-only-usable-as-date-part-specifiers -# ISO day of week function in duckdb is ISODOW -class DayOfWeekIso(Func): - _sql_names = ["DAYOFWEEK_ISO", "ISODOW"] - - -class DayOfMonth(Func): - _sql_names = ["DAY_OF_MONTH", "DAYOFMONTH"] - - -class DayOfYear(Func): - _sql_names = ["DAY_OF_YEAR", "DAYOFYEAR"] - - -class ToDays(Func): - pass - - -class WeekOfYear(Func): - _sql_names = ["WEEK_OF_YEAR", "WEEKOFYEAR"] - - -class MonthsBetween(Func): - arg_types = {"this": True, "expression": True, "roundoff": False} - - -class MakeInterval(Func): - arg_types = { - "year": False, - "month": False, - "day": False, - "hour": False, - "minute": False, - "second": False, - } - - -class LastDay(Func, TimeUnit): - _sql_names = ["LAST_DAY", "LAST_DAY_OF_MONTH"] - arg_types = {"this": True, "unit": False} - - -class Extract(Func): - arg_types = {"this": True, "expression": True} - - -class Exists(Func, SubqueryPredicate): - arg_types = {"this": True, "expression": False} - - -class Timestamp(Func): - arg_types = {"this": False, "zone": False, "with_tz": False} - - -class TimestampAdd(Func, TimeUnit): - arg_types = {"this": True, "expression": True, "unit": False} - - -class TimestampSub(Func, TimeUnit): - arg_types = {"this": True, "expression": True, "unit": False} - - -class TimestampDiff(Func, TimeUnit): - _sql_names = ["TIMESTAMPDIFF", "TIMESTAMP_DIFF"] - arg_types = {"this": True, "expression": True, "unit": False} - - -class TimestampTrunc(Func, TimeUnit): - arg_types = {"this": True, "unit": True, "zone": False} - - -class TimeAdd(Func, TimeUnit): - arg_types = {"this": True, "expression": True, "unit": False} - - -class TimeSub(Func, TimeUnit): - arg_types = {"this": True, "expression": True, "unit": False} - - -class TimeDiff(Func, TimeUnit): - arg_types = {"this": True, "expression": True, "unit": False} - - -class TimeTrunc(Func, TimeUnit): - arg_types = {"this": True, "unit": True, "zone": False} - - -class DateFromParts(Func): - _sql_names = ["DATE_FROM_PARTS", "DATEFROMPARTS"] - arg_types = {"year": True, "month": True, "day": True} - - -class TimeFromParts(Func): - _sql_names = 
["TIME_FROM_PARTS", "TIMEFROMPARTS"] - arg_types = { - "hour": True, - "min": True, - "sec": True, - "nano": False, - "fractions": False, - "precision": False, - } - - -class DateStrToDate(Func): - pass - - -class DateToDateStr(Func): - pass - - -class DateToDi(Func): - pass - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#date -class Date(Func): - arg_types = {"this": False, "zone": False, "expressions": False} - is_var_len_args = True - - -class Day(Func): - pass - - -class Decode(Func): - arg_types = {"this": True, "charset": True, "replace": False} - - -class DiToDate(Func): - pass - - -class Encode(Func): - arg_types = {"this": True, "charset": True} - - -class Exp(Func): - pass - - -# https://docs.snowflake.com/en/sql-reference/functions/flatten -class Explode(Func, UDTF): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - -# https://spark.apache.org/docs/latest/api/sql/#inline -class Inline(Func): - pass - - -class ExplodeOuter(Explode): - pass - - -class Posexplode(Explode): - pass - - -class PosexplodeOuter(Posexplode, ExplodeOuter): - pass - - -class Unnest(Func, UDTF): - arg_types = { - "expressions": True, - "alias": False, - "offset": False, - "explode_array": False, - } - - @property - def selects(self) -> t.List[Expression]: - columns = super().selects - offset = self.args.get("offset") - if offset: - columns = columns + [to_identifier("offset") if offset is True else offset] - return columns - - -class Floor(Func): - arg_types = {"this": True, "decimals": False, "to": False} - - -class FromBase64(Func): - pass - - -class FeaturesAtTime(Func): - arg_types = {"this": True, "time": False, "num_rows": False, "ignore_feature_nulls": False} - - -class ToBase64(Func): - pass - - -# https://trino.io/docs/current/functions/datetime.html#from_iso8601_timestamp -class FromISO8601Timestamp(Func): - _sql_names = ["FROM_ISO8601_TIMESTAMP"] - - -class GapFill(Func): - arg_types = { - "this": True, - "ts_column": True, - "bucket_width": True, - "partitioning_columns": False, - "value_columns": False, - "origin": False, - "ignore_nulls": False, - } - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/array_functions#generate_date_array -class GenerateDateArray(Func): - arg_types = {"start": True, "end": True, "step": False} - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/array_functions#generate_timestamp_array -class GenerateTimestampArray(Func): - arg_types = {"start": True, "end": True, "step": True} - - -class Greatest(Func): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - -# Trino's `ON OVERFLOW TRUNCATE [filler_string] {WITH | WITHOUT} COUNT` -# https://trino.io/docs/current/functions/aggregate.html#listagg -class OverflowTruncateBehavior(Expression): - arg_types = {"this": False, "with_count": True} - - -class GroupConcat(AggFunc): - arg_types = {"this": True, "separator": False, "on_overflow": False} - - -class Hex(Func): - pass - - -class LowerHex(Hex): - pass - - -class And(Connector, Func): - pass - - -class Or(Connector, Func): - pass - - -class Xor(Connector, Func): - arg_types = {"this": False, "expression": False, "expressions": False} - - -class If(Func): - arg_types = {"this": True, "true": True, "false": False} - _sql_names = ["IF", "IIF"] - - -class Nullif(Func): - arg_types = {"this": True, "expression": True} - - -class Initcap(Func): - arg_types = {"this": True, "expression": False} - - -class IsAscii(Func): - pass - - -class IsNan(Func): 
- _sql_names = ["IS_NAN", "ISNAN"] - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/json_functions#int64_for_json -class Int64(Func): - pass - - -class IsInf(Func): - _sql_names = ["IS_INF", "ISINF"] - - -# https://www.postgresql.org/docs/current/functions-json.html -class JSON(Expression): - arg_types = {"this": False, "with": False, "unique": False} - - -class JSONPath(Expression): - arg_types = {"expressions": True, "escape": False} - - @property - def output_name(self) -> str: - last_segment = self.expressions[-1].this - return last_segment if isinstance(last_segment, str) else "" - - -class JSONPathPart(Expression): - arg_types = {} - - -class JSONPathFilter(JSONPathPart): - arg_types = {"this": True} - - -class JSONPathKey(JSONPathPart): - arg_types = {"this": True} - - -class JSONPathRecursive(JSONPathPart): - arg_types = {"this": False} - - -class JSONPathRoot(JSONPathPart): - pass - - -class JSONPathScript(JSONPathPart): - arg_types = {"this": True} - - -class JSONPathSlice(JSONPathPart): - arg_types = {"start": False, "end": False, "step": False} - - -class JSONPathSelector(JSONPathPart): - arg_types = {"this": True} - - -class JSONPathSubscript(JSONPathPart): - arg_types = {"this": True} - - -class JSONPathUnion(JSONPathPart): - arg_types = {"expressions": True} - - -class JSONPathWildcard(JSONPathPart): - pass - - -class FormatJson(Expression): - pass - - -class JSONKeyValue(Expression): - arg_types = {"this": True, "expression": True} - - -class JSONObject(Func): - arg_types = { - "expressions": False, - "null_handling": False, - "unique_keys": False, - "return_type": False, - "encoding": False, - } - - -class JSONObjectAgg(AggFunc): - arg_types = { - "expressions": False, - "null_handling": False, - "unique_keys": False, - "return_type": False, - "encoding": False, - } - - -# https://www.postgresql.org/docs/9.5/functions-aggregate.html -class JSONBObjectAgg(AggFunc): - arg_types = {"this": True, "expression": True} - - -# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAY.html -class JSONArray(Func): - arg_types = { - "expressions": True, - "null_handling": False, - "return_type": False, - "strict": False, - } - - -# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_ARRAYAGG.html -class JSONArrayAgg(Func): - arg_types = { - "this": True, - "order": False, - "null_handling": False, - "return_type": False, - "strict": False, - } - - -class JSONExists(Func): - arg_types = {"this": True, "path": True, "passing": False, "on_condition": False} - - -# https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html -# Note: parsing of JSON column definitions is currently incomplete. 
-class JSONColumnDef(Expression): - arg_types = {"this": False, "kind": False, "path": False, "nested_schema": False} - - -class JSONSchema(Expression): - arg_types = {"expressions": True} - - -# https://dev.mysql.com/doc/refman/8.4/en/json-search-functions.html#function_json-value -class JSONValue(Expression): - arg_types = { - "this": True, - "path": True, - "returning": False, - "on_condition": False, - } - - -class JSONValueArray(Func): - arg_types = {"this": True, "expression": False} - - -# # https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/JSON_TABLE.html -class JSONTable(Func): - arg_types = { - "this": True, - "schema": True, - "path": False, - "error_handling": False, - "empty_handling": False, - } - - -# https://docs.snowflake.com/en/sql-reference/functions/object_insert -class ObjectInsert(Func): - arg_types = { - "this": True, - "key": True, - "value": True, - "update_flag": False, - } - - -class OpenJSONColumnDef(Expression): - arg_types = {"this": True, "kind": True, "path": False, "as_json": False} - - -class OpenJSON(Func): - arg_types = {"this": True, "path": False, "expressions": False} - - -class JSONBContains(Binary, Func): - _sql_names = ["JSONB_CONTAINS"] - - -class JSONBExists(Func): - arg_types = {"this": True, "path": True} - _sql_names = ["JSONB_EXISTS"] - - -class JSONExtract(Binary, Func): - arg_types = { - "this": True, - "expression": True, - "only_json_types": False, - "expressions": False, - "variant_extract": False, - "json_query": False, - "option": False, - "quote": False, - "on_condition": False, - } - _sql_names = ["JSON_EXTRACT"] - is_var_len_args = True - - @property - def output_name(self) -> str: - return self.expression.output_name if not self.expressions else "" - - -# https://trino.io/docs/current/functions/json.html#json-query -class JSONExtractQuote(Expression): - arg_types = { - "option": True, - "scalar": False, - } - - -class JSONExtractArray(Func): - arg_types = {"this": True, "expression": False} - _sql_names = ["JSON_EXTRACT_ARRAY"] - - -class JSONExtractScalar(Binary, Func): - arg_types = {"this": True, "expression": True, "only_json_types": False, "expressions": False} - _sql_names = ["JSON_EXTRACT_SCALAR"] - is_var_len_args = True - - @property - def output_name(self) -> str: - return self.expression.output_name - - -class JSONBExtract(Binary, Func): - _sql_names = ["JSONB_EXTRACT"] - - -class JSONBExtractScalar(Binary, Func): - _sql_names = ["JSONB_EXTRACT_SCALAR"] - - -class JSONFormat(Func): - arg_types = {"this": False, "options": False, "is_json": False} - _sql_names = ["JSON_FORMAT"] - - -# https://dev.mysql.com/doc/refman/8.0/en/json-search-functions.html#operator_member-of -class JSONArrayContains(Binary, Predicate, Func): - _sql_names = ["JSON_ARRAY_CONTAINS"] - - -class ParseJSON(Func): - # BigQuery, Snowflake have PARSE_JSON, Presto has JSON_PARSE - # Snowflake also has TRY_PARSE_JSON, which is represented using `safe` - _sql_names = ["PARSE_JSON", "JSON_PARSE"] - arg_types = {"this": True, "expression": False, "safe": False} - - -class Least(Func): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - -class Left(Func): - arg_types = {"this": True, "expression": True} - - -class Right(Func): - arg_types = {"this": True, "expression": True} - - -class Length(Func): - arg_types = {"this": True, "binary": False, "encoding": False} - _sql_names = ["LENGTH", "LEN", "CHAR_LENGTH", "CHARACTER_LENGTH"] - - -class Levenshtein(Func): - arg_types = { - "this": True, - "expression": False, 
- "ins_cost": False, - "del_cost": False, - "sub_cost": False, - "max_dist": False, - } - - -class Ln(Func): - pass - - -class Log(Func): - arg_types = {"this": True, "expression": False} - - -class LogicalOr(AggFunc): - _sql_names = ["LOGICAL_OR", "BOOL_OR", "BOOLOR_AGG"] - - -class LogicalAnd(AggFunc): - _sql_names = ["LOGICAL_AND", "BOOL_AND", "BOOLAND_AGG"] - - -class Lower(Func): - _sql_names = ["LOWER", "LCASE"] - - -class Map(Func): - arg_types = {"keys": False, "values": False} - - @property - def keys(self) -> t.List[Expression]: - keys = self.args.get("keys") - return keys.expressions if keys else [] - - @property - def values(self) -> t.List[Expression]: - values = self.args.get("values") - return values.expressions if values else [] - - -# Represents the MAP {...} syntax in DuckDB - basically convert a struct to a MAP -class ToMap(Func): - pass - - -class MapFromEntries(Func): - pass - - -# https://learn.microsoft.com/en-us/sql/t-sql/language-elements/scope-resolution-operator-transact-sql?view=sql-server-ver16 -class ScopeResolution(Expression): - arg_types = {"this": False, "expression": True} - - -class Stream(Expression): - pass - - -class StarMap(Func): - pass - - -class VarMap(Func): - arg_types = {"keys": True, "values": True} - is_var_len_args = True - - @property - def keys(self) -> t.List[Expression]: - return self.args["keys"].expressions - - @property - def values(self) -> t.List[Expression]: - return self.args["values"].expressions - - -# https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html -class MatchAgainst(Func): - arg_types = {"this": True, "expressions": True, "modifier": False} - - -class Max(AggFunc): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - -class MD5(Func): - _sql_names = ["MD5"] - - -# Represents the variant of the MD5 function that returns a binary value -class MD5Digest(Func): - _sql_names = ["MD5_DIGEST"] - - -class Median(AggFunc): - pass - - -class Min(AggFunc): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - - -class Month(Func): - pass - - -class AddMonths(Func): - arg_types = {"this": True, "expression": True} - - -class Nvl2(Func): - arg_types = {"this": True, "true": True, "false": False} - - -class Normalize(Func): - arg_types = {"this": True, "form": False} - - -class Overlay(Func): - arg_types = {"this": True, "expression": True, "from": True, "for": False} - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict#mlpredict_function -class Predict(Func): - arg_types = {"this": True, "expression": True, "params_struct": False} - - -class Pow(Binary, Func): - _sql_names = ["POWER", "POW"] - - -class PercentileCont(AggFunc): - arg_types = {"this": True, "expression": False} - - -class PercentileDisc(AggFunc): - arg_types = {"this": True, "expression": False} - - -class Quantile(AggFunc): - arg_types = {"this": True, "quantile": True} - - -class ApproxQuantile(Quantile): - arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} - - -class Quarter(Func): - pass - - -# https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/SQL-Functions-Expressions-and-Predicates/Arithmetic-Trigonometric-Hyperbolic-Operators/Functions/RANDOM/RANDOM-Function-Syntax -# teradata lower and upper bounds -class Rand(Func): - _sql_names = ["RAND", "RANDOM"] - arg_types = {"this": False, "lower": False, "upper": False} - - -class Randn(Func): - arg_types = {"this": False} - - -class RangeN(Func): - arg_types = {"this": True, 
"expressions": True, "each": False} - - -class ReadCSV(Func): - _sql_names = ["READ_CSV"] - is_var_len_args = True - arg_types = {"this": True, "expressions": False} - - -class Reduce(Func): - arg_types = {"this": True, "initial": True, "merge": True, "finish": False} - - -class RegexpExtract(Func): - arg_types = { - "this": True, - "expression": True, - "position": False, - "occurrence": False, - "parameters": False, - "group": False, - } - - -class RegexpExtractAll(Func): - arg_types = { - "this": True, - "expression": True, - "position": False, - "occurrence": False, - "parameters": False, - "group": False, - } - - -class RegexpReplace(Func): - arg_types = { - "this": True, - "expression": True, - "replacement": False, - "position": False, - "occurrence": False, - "modifiers": False, - } - - -class RegexpLike(Binary, Func): - arg_types = {"this": True, "expression": True, "flag": False} - - -class RegexpILike(Binary, Func): - arg_types = {"this": True, "expression": True, "flag": False} - - -# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split.html -# limit is the number of times a pattern is applied -class RegexpSplit(Func): - arg_types = {"this": True, "expression": True, "limit": False} - - -class Repeat(Func): - arg_types = {"this": True, "times": True} - - -# https://learn.microsoft.com/en-us/sql/t-sql/functions/round-transact-sql?view=sql-server-ver16 -# tsql third argument function == trunctaion if not 0 -class Round(Func): - arg_types = {"this": True, "decimals": False, "truncate": False} - - -class RowNumber(Func): - arg_types = {"this": False} - - -class SafeDivide(Func): - arg_types = {"this": True, "expression": True} - - -class SHA(Func): - _sql_names = ["SHA", "SHA1"] - - -class SHA2(Func): - _sql_names = ["SHA2"] - arg_types = {"this": True, "length": False} - - -class Sign(Func): - _sql_names = ["SIGN", "SIGNUM"] - - -class SortArray(Func): - arg_types = {"this": True, "asc": False} - - -class Split(Func): - arg_types = {"this": True, "expression": True, "limit": False} - - -# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split_part.html -class SplitPart(Func): - arg_types = {"this": True, "delimiter": True, "part_index": True} - - -# Start may be omitted in the case of postgres -# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 -class Substring(Func): - _sql_names = ["SUBSTRING", "SUBSTR"] - arg_types = {"this": True, "start": False, "length": False} - - -class StandardHash(Func): - arg_types = {"this": True, "expression": False} - - -class StartsWith(Func): - _sql_names = ["STARTS_WITH", "STARTSWITH"] - arg_types = {"this": True, "expression": True} - - -class EndsWith(Func): - _sql_names = ["ENDS_WITH", "ENDSWITH"] - arg_types = {"this": True, "expression": True} - - -class StrPosition(Func): - arg_types = { - "this": True, - "substr": True, - "position": False, - "occurrence": False, - } - - -class StrToDate(Func): - arg_types = {"this": True, "format": False, "safe": False} - - -class StrToTime(Func): - arg_types = {"this": True, "format": True, "zone": False, "safe": False} - - -# Spark allows unix_timestamp() -# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html -class StrToUnix(Func): - arg_types = {"this": False, "format": False} - - -# https://prestodb.io/docs/current/functions/string.html -# https://spark.apache.org/docs/latest/api/sql/index.html#str_to_map -class StrToMap(Func): - 
arg_types = { - "this": True, - "pair_delim": False, - "key_value_delim": False, - "duplicate_resolution_callback": False, - } - - -class NumberToStr(Func): - arg_types = {"this": True, "format": True, "culture": False} - - -class FromBase(Func): - arg_types = {"this": True, "expression": True} - - -class Struct(Func): - arg_types = {"expressions": False} - is_var_len_args = True - - -class StructExtract(Func): - arg_types = {"this": True, "expression": True} - - -# https://learn.microsoft.com/en-us/sql/t-sql/functions/stuff-transact-sql?view=sql-server-ver16 -# https://docs.snowflake.com/en/sql-reference/functions/insert -class Stuff(Func): - _sql_names = ["STUFF", "INSERT"] - arg_types = {"this": True, "start": True, "length": True, "expression": True} - - -class Sum(AggFunc): - pass - - -class Sqrt(Func): - pass - - -class Stddev(AggFunc): - _sql_names = ["STDDEV", "STDEV"] - - -class StddevPop(AggFunc): - pass - - -class StddevSamp(AggFunc): - pass - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/time_functions#time -class Time(Func): - arg_types = {"this": False, "zone": False} - - -class TimeToStr(Func): - arg_types = {"this": True, "format": True, "culture": False, "zone": False} - - -class TimeToTimeStr(Func): - pass - - -class TimeToUnix(Func): - pass - - -class TimeStrToDate(Func): - pass - - -class TimeStrToTime(Func): - arg_types = {"this": True, "zone": False} - - -class TimeStrToUnix(Func): - pass - - -class Trim(Func): - arg_types = { - "this": True, - "expression": False, - "position": False, - "collation": False, - } - - -class TsOrDsAdd(Func, TimeUnit): - # return_type is used to correctly cast the arguments of this expression when transpiling it - arg_types = {"this": True, "expression": True, "unit": False, "return_type": False} - - @property - def return_type(self) -> DataType: - return DataType.build(self.args.get("return_type") or DataType.Type.DATE) - - -class TsOrDsDiff(Func, TimeUnit): - arg_types = {"this": True, "expression": True, "unit": False} - - -class TsOrDsToDateStr(Func): - pass - - -class TsOrDsToDate(Func): - arg_types = {"this": True, "format": False, "safe": False} - - -class TsOrDsToDatetime(Func): - pass - - -class TsOrDsToTime(Func): - arg_types = {"this": True, "format": False, "safe": False} - - -class TsOrDsToTimestamp(Func): - pass - - -class TsOrDiToDi(Func): - pass - - -class Unhex(Func): - arg_types = {"this": True, "expression": False} - - -class Unicode(Func): - pass - - -# https://cloud.google.com/bigquery/docs/reference/standard-sql/date_functions#unix_date -class UnixDate(Func): - pass - - -class UnixToStr(Func): - arg_types = {"this": True, "format": False} - - -# https://prestodb.io/docs/current/functions/datetime.html -# presto has weird zone/hours/minutes -class UnixToTime(Func): - arg_types = { - "this": True, - "scale": False, - "zone": False, - "hours": False, - "minutes": False, - "format": False, - } - - SECONDS = Literal.number(0) - DECIS = Literal.number(1) - CENTIS = Literal.number(2) - MILLIS = Literal.number(3) - DECIMILLIS = Literal.number(4) - CENTIMILLIS = Literal.number(5) - MICROS = Literal.number(6) - DECIMICROS = Literal.number(7) - CENTIMICROS = Literal.number(8) - NANOS = Literal.number(9) - - -class UnixToTimeStr(Func): - pass - - -class UnixSeconds(Func): - pass - - -class Uuid(Func): - _sql_names = ["UUID", "GEN_RANDOM_UUID", "GENERATE_UUID", "UUID_STRING"] - - arg_types = {"this": False, "name": False} - - -class TimestampFromParts(Func): - _sql_names = ["TIMESTAMP_FROM_PARTS", 
"TIMESTAMPFROMPARTS"] - arg_types = { - "year": True, - "month": True, - "day": True, - "hour": True, - "min": True, - "sec": True, - "nano": False, - "zone": False, - "milli": False, - } - - -class Upper(Func): - _sql_names = ["UPPER", "UCASE"] - - -class Corr(Binary, AggFunc): - pass - - -class Variance(AggFunc): - _sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"] - - -class VariancePop(AggFunc): - _sql_names = ["VARIANCE_POP", "VAR_POP"] - - -class CovarSamp(Binary, AggFunc): - pass - - -class CovarPop(Binary, AggFunc): - pass - - -class Week(Func): - arg_types = {"this": True, "mode": False} - - -class XMLElement(Func): - _sql_names = ["XMLELEMENT"] - arg_types = {"this": True, "expressions": False} - - -class XMLTable(Func): - arg_types = { - "this": True, - "namespaces": False, - "passing": False, - "columns": False, - "by_ref": False, - } - - -class XMLNamespace(Expression): - pass - - -# https://learn.microsoft.com/en-us/sql/t-sql/queries/select-for-clause-transact-sql?view=sql-server-ver17#syntax -class XMLKeyValueOption(Expression): - arg_types = {"this": True, "expression": False} - - -class Year(Func): - pass - - -class Use(Expression): - arg_types = {"this": False, "expressions": False, "kind": False} - - -class Merge(DML): - arg_types = { - "this": True, - "using": True, - "on": True, - "whens": True, - "with": False, - "returning": False, - } - - -class When(Expression): - arg_types = {"matched": True, "source": False, "condition": False, "then": True} - - -class Whens(Expression): - """Wraps around one or more WHEN [NOT] MATCHED [...] clauses.""" - - arg_types = {"expressions": True} - - -# https://docs.oracle.com/javadb/10.8.3.0/ref/rrefsqljnextvaluefor.html -# https://learn.microsoft.com/en-us/sql/t-sql/functions/next-value-for-transact-sql?view=sql-server-ver16 -class NextValueFor(Func): - arg_types = {"this": True, "order": False} - - -# Refers to a trailing semi-colon. This is only used to preserve trailing comments -# select 1; -- my comment -class Semicolon(Expression): - arg_types = {} - - -def _norm_arg(arg): - return arg.lower() if type(arg) is str else arg - - -ALL_FUNCTIONS = subclasses(__name__, Func, (AggFunc, Anonymous, Func)) -FUNCTION_BY_NAME = {name: func for func in ALL_FUNCTIONS for name in func.sql_names()} - -JSON_PATH_PARTS = subclasses(__name__, JSONPathPart, (JSONPathPart,)) - -PERCENTILES = (PercentileCont, PercentileDisc) - - -# Helpers -@t.overload -def maybe_parse( - sql_or_expression: ExpOrStr, - *, - into: t.Type[E], - dialect: DialectType = None, - prefix: t.Optional[str] = None, - copy: bool = False, - **opts, -) -> E: ... - - -@t.overload -def maybe_parse( - sql_or_expression: str | E, - *, - into: t.Optional[IntoType] = None, - dialect: DialectType = None, - prefix: t.Optional[str] = None, - copy: bool = False, - **opts, -) -> E: ... - - -def maybe_parse( - sql_or_expression: ExpOrStr, - *, - into: t.Optional[IntoType] = None, - dialect: DialectType = None, - prefix: t.Optional[str] = None, - copy: bool = False, - **opts, -) -> Expression: - """Gracefully handle a possible string or expression. - - Example: - >>> maybe_parse("1") - Literal(this=1, is_string=False) - >>> maybe_parse(to_identifier("x")) - Identifier(this=x, quoted=False) - - Args: - sql_or_expression: the SQL code string or an expression - into: the SQLGlot Expression to parse into - dialect: the dialect used to parse the input expressions (in the case that an - input expression is a SQL string). 
- prefix: a string to prefix the sql with before it gets parsed - (automatically includes a space) - copy: whether to copy the expression. - **opts: other options to use to parse the input expressions (again, in the case - that an input expression is a SQL string). - - Returns: - Expression: the parsed or given expression. - """ - if isinstance(sql_or_expression, Expression): - if copy: - return sql_or_expression.copy() - return sql_or_expression - - if sql_or_expression is None: - raise ParseError("SQL cannot be None") - - import sqlglot - - sql = str(sql_or_expression) - if prefix: - sql = f"{prefix} {sql}" - - return sqlglot.parse_one(sql, read=dialect, into=into, **opts) - - -@t.overload -def maybe_copy(instance: None, copy: bool = True) -> None: ... - - -@t.overload -def maybe_copy(instance: E, copy: bool = True) -> E: ... - - -def maybe_copy(instance, copy=True): - return instance.copy() if copy and instance else instance - - -def _to_s(node: t.Any, verbose: bool = False, level: int = 0, repr_str: bool = False) -> str: - """Generate a textual representation of an Expression tree""" - indent = "\n" + (" " * (level + 1)) - delim = f",{indent}" - - if isinstance(node, Expression): - args = {k: v for k, v in node.args.items() if (v is not None and v != []) or verbose} - - if (node.type or verbose) and not isinstance(node, DataType): - args["_type"] = node.type - if node.comments or verbose: - args["_comments"] = node.comments - - if verbose: - args["_id"] = id(node) - - # Inline leaves for a more compact representation - if node.is_leaf(): - indent = "" - delim = ", " - - repr_str = node.is_string or (isinstance(node, Identifier) and node.quoted) - items = delim.join( - [f"{k}={_to_s(v, verbose, level + 1, repr_str=repr_str)}" for k, v in args.items()] - ) - return f"{node.__class__.__name__}({indent}{items})" - - if isinstance(node, list): - items = delim.join(_to_s(i, verbose, level + 1) for i in node) - items = f"{indent}{items}" if items else "" - return f"[{items}]" - - # We use the representation of the string to avoid stripping out important whitespace - if repr_str and isinstance(node, str): - node = repr(node) - - # Indent multiline strings to match the current level - return indent.join(textwrap.dedent(str(node).strip("\n")).splitlines()) - - -def _is_wrong_expression(expression, into): - return isinstance(expression, Expression) and not isinstance(expression, into) - - -def _apply_builder( - expression, - instance, - arg, - copy=True, - prefix=None, - into=None, - dialect=None, - into_arg="this", - **opts, -): - if _is_wrong_expression(expression, into): - expression = into(**{into_arg: expression}) - instance = maybe_copy(instance, copy) - expression = maybe_parse( - sql_or_expression=expression, - prefix=prefix, - into=into, - dialect=dialect, - **opts, - ) - instance.set(arg, expression) - return instance - - -def _apply_child_list_builder( - *expressions, - instance, - arg, - append=True, - copy=True, - prefix=None, - into=None, - dialect=None, - properties=None, - **opts, -): - instance = maybe_copy(instance, copy) - parsed = [] - properties = {} if properties is None else properties - - for expression in expressions: - if expression is not None: - if _is_wrong_expression(expression, into): - expression = into(expressions=[expression]) - - expression = maybe_parse( - expression, - into=into, - dialect=dialect, - prefix=prefix, - **opts, - ) - for k, v in expression.args.items(): - if k == "expressions": - parsed.extend(v) - else: - properties[k] = v - - existing = 
instance.args.get(arg) - if append and existing: - parsed = existing.expressions + parsed - - child = into(expressions=parsed) - for k, v in properties.items(): - child.set(k, v) - instance.set(arg, child) - - return instance - - -def _apply_list_builder( - *expressions, - instance, - arg, - append=True, - copy=True, - prefix=None, - into=None, - dialect=None, - **opts, -): - inst = maybe_copy(instance, copy) - - expressions = [ - maybe_parse( - sql_or_expression=expression, - into=into, - prefix=prefix, - dialect=dialect, - **opts, - ) - for expression in expressions - if expression is not None - ] - - existing_expressions = inst.args.get(arg) - if append and existing_expressions: - expressions = existing_expressions + expressions - - inst.set(arg, expressions) - return inst - - -def _apply_conjunction_builder( - *expressions, - instance, - arg, - into=None, - append=True, - copy=True, - dialect=None, - **opts, -): - expressions = [exp for exp in expressions if exp is not None and exp != ""] - if not expressions: - return instance - - inst = maybe_copy(instance, copy) - - existing = inst.args.get(arg) - if append and existing is not None: - expressions = [existing.this if into else existing] + list(expressions) - - node = and_(*expressions, dialect=dialect, copy=copy, **opts) - - inst.set(arg, into(this=node) if into else node) - return inst - - -def _apply_cte_builder( - instance: E, - alias: ExpOrStr, - as_: ExpOrStr, - recursive: t.Optional[bool] = None, - materialized: t.Optional[bool] = None, - append: bool = True, - dialect: DialectType = None, - copy: bool = True, - scalar: bool = False, - **opts, -) -> E: - alias_expression = maybe_parse(alias, dialect=dialect, into=TableAlias, **opts) - as_expression = maybe_parse(as_, dialect=dialect, copy=copy, **opts) - if scalar and not isinstance(as_expression, Subquery): - # scalar CTE must be wrapped in a subquery - as_expression = Subquery(this=as_expression) - cte = CTE(this=as_expression, alias=alias_expression, materialized=materialized, scalar=scalar) - return _apply_child_list_builder( - cte, - instance=instance, - arg="with", - append=append, - copy=copy, - into=With, - properties={"recursive": recursive or False}, - ) - - -def _combine( - expressions: t.Sequence[t.Optional[ExpOrStr]], - operator: t.Type[Connector], - dialect: DialectType = None, - copy: bool = True, - wrap: bool = True, - **opts, -) -> Expression: - conditions = [ - condition(expression, dialect=dialect, copy=copy, **opts) - for expression in expressions - if expression is not None - ] - - this, *rest = conditions - if rest and wrap: - this = _wrap(this, Connector) - for expression in rest: - this = operator(this=this, expression=_wrap(expression, Connector) if wrap else expression) - - return this - - -@t.overload -def _wrap(expression: None, kind: t.Type[Expression]) -> None: ... - - -@t.overload -def _wrap(expression: E, kind: t.Type[Expression]) -> E | Paren: ... 
- - -def _wrap(expression: t.Optional[E], kind: t.Type[Expression]) -> t.Optional[E] | Paren: - return Paren(this=expression) if isinstance(expression, kind) else expression - - -def _apply_set_operation( - *expressions: ExpOrStr, - set_operation: t.Type[S], - distinct: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, -) -> S: - return reduce( - lambda x, y: set_operation(this=x, expression=y, distinct=distinct, **opts), - (maybe_parse(e, dialect=dialect, copy=copy, **opts) for e in expressions), - ) - - -def union( - *expressions: ExpOrStr, - distinct: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, -) -> Union: - """ - Initializes a syntax tree for the `UNION` operation. - - Example: - >>> union("SELECT * FROM foo", "SELECT * FROM bla").sql() - 'SELECT * FROM foo UNION SELECT * FROM bla' - - Args: - expressions: the SQL code strings, corresponding to the `UNION`'s operands. - If `Expression` instances are passed, they will be used as-is. - distinct: set the DISTINCT flag if and only if this is true. - dialect: the dialect used to parse the input expression. - copy: whether to copy the expression. - opts: other options to use to parse the input expressions. - - Returns: - The new Union instance. - """ - assert len(expressions) >= 2, "At least two expressions are required by `union`." - return _apply_set_operation( - *expressions, set_operation=Union, distinct=distinct, dialect=dialect, copy=copy, **opts - ) - - -def intersect( - *expressions: ExpOrStr, - distinct: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, -) -> Intersect: - """ - Initializes a syntax tree for the `INTERSECT` operation. - - Example: - >>> intersect("SELECT * FROM foo", "SELECT * FROM bla").sql() - 'SELECT * FROM foo INTERSECT SELECT * FROM bla' - - Args: - expressions: the SQL code strings, corresponding to the `INTERSECT`'s operands. - If `Expression` instances are passed, they will be used as-is. - distinct: set the DISTINCT flag if and only if this is true. - dialect: the dialect used to parse the input expression. - copy: whether to copy the expression. - opts: other options to use to parse the input expressions. - - Returns: - The new Intersect instance. - """ - assert len(expressions) >= 2, "At least two expressions are required by `intersect`." - return _apply_set_operation( - *expressions, set_operation=Intersect, distinct=distinct, dialect=dialect, copy=copy, **opts - ) - - -def except_( - *expressions: ExpOrStr, - distinct: bool = True, - dialect: DialectType = None, - copy: bool = True, - **opts, -) -> Except: - """ - Initializes a syntax tree for the `EXCEPT` operation. - - Example: - >>> except_("SELECT * FROM foo", "SELECT * FROM bla").sql() - 'SELECT * FROM foo EXCEPT SELECT * FROM bla' - - Args: - expressions: the SQL code strings, corresponding to the `EXCEPT`'s operands. - If `Expression` instances are passed, they will be used as-is. - distinct: set the DISTINCT flag if and only if this is true. - dialect: the dialect used to parse the input expression. - copy: whether to copy the expression. - opts: other options to use to parse the input expressions. - - Returns: - The new Except instance. - """ - assert len(expressions) >= 2, "At least two expressions are required by `except_`." 
- return _apply_set_operation( - *expressions, set_operation=Except, distinct=distinct, dialect=dialect, copy=copy, **opts - ) - - -def select(*expressions: ExpOrStr, dialect: DialectType = None, **opts) -> Select: - """ - Initializes a syntax tree from one or multiple SELECT expressions. - - Example: - >>> select("col1", "col2").from_("tbl").sql() - 'SELECT col1, col2 FROM tbl' - - Args: - *expressions: the SQL code string to parse as the expressions of a - SELECT statement. If an Expression instance is passed, this is used as-is. - dialect: the dialect used to parse the input expressions (in the case that an - input expression is a SQL string). - **opts: other options to use to parse the input expressions (again, in the case - that an input expression is a SQL string). - - Returns: - Select: the syntax tree for the SELECT statement. - """ - return Select().select(*expressions, dialect=dialect, **opts) - - -def from_(expression: ExpOrStr, dialect: DialectType = None, **opts) -> Select: - """ - Initializes a syntax tree from a FROM expression. - - Example: - >>> from_("tbl").select("col1", "col2").sql() - 'SELECT col1, col2 FROM tbl' - - Args: - *expression: the SQL code string to parse as the FROM expressions of a - SELECT statement. If an Expression instance is passed, this is used as-is. - dialect: the dialect used to parse the input expression (in the case that the - input expression is a SQL string). - **opts: other options to use to parse the input expressions (again, in the case - that the input expression is a SQL string). - - Returns: - Select: the syntax tree for the SELECT statement. - """ - return Select().from_(expression, dialect=dialect, **opts) - - -def update( - table: str | Table, - properties: t.Optional[dict] = None, - where: t.Optional[ExpOrStr] = None, - from_: t.Optional[ExpOrStr] = None, - with_: t.Optional[t.Dict[str, ExpOrStr]] = None, - dialect: DialectType = None, - **opts, -) -> Update: - """ - Creates an update statement. - - Example: - >>> update("my_table", {"x": 1, "y": "2", "z": None}, from_="baz_cte", where="baz_cte.id > 1 and my_table.id = baz_cte.id", with_={"baz_cte": "SELECT id FROM foo"}).sql() - "WITH baz_cte AS (SELECT id FROM foo) UPDATE my_table SET x = 1, y = '2', z = NULL FROM baz_cte WHERE baz_cte.id > 1 AND my_table.id = baz_cte.id" - - Args: - properties: dictionary of properties to SET which are - auto converted to sql objects eg None -> NULL - where: sql conditional parsed into a WHERE statement - from_: sql statement parsed into a FROM statement - with_: dictionary of CTE aliases / select statements to include in a WITH clause. - dialect: the dialect used to parse the input expressions. - **opts: other options to use to parse the input expressions. - - Returns: - Update: the syntax tree for the UPDATE statement. 
- """ - update_expr = Update(this=maybe_parse(table, into=Table, dialect=dialect)) - if properties: - update_expr.set( - "expressions", - [ - EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) - for k, v in properties.items() - ], - ) - if from_: - update_expr.set( - "from", - maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), - ) - if isinstance(where, Condition): - where = Where(this=where) - if where: - update_expr.set( - "where", - maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), - ) - if with_: - cte_list = [ - alias_(CTE(this=maybe_parse(qry, dialect=dialect, **opts)), alias, table=True) - for alias, qry in with_.items() - ] - update_expr.set( - "with", - With(expressions=cte_list), - ) - return update_expr - - -def delete( - table: ExpOrStr, - where: t.Optional[ExpOrStr] = None, - returning: t.Optional[ExpOrStr] = None, - dialect: DialectType = None, - **opts, -) -> Delete: - """ - Builds a delete statement. - - Example: - >>> delete("my_table", where="id > 1").sql() - 'DELETE FROM my_table WHERE id > 1' - - Args: - where: sql conditional parsed into a WHERE statement - returning: sql conditional parsed into a RETURNING statement - dialect: the dialect used to parse the input expressions. - **opts: other options to use to parse the input expressions. - - Returns: - Delete: the syntax tree for the DELETE statement. - """ - delete_expr = Delete().delete(table, dialect=dialect, copy=False, **opts) - if where: - delete_expr = delete_expr.where(where, dialect=dialect, copy=False, **opts) - if returning: - delete_expr = delete_expr.returning(returning, dialect=dialect, copy=False, **opts) - return delete_expr - - -def insert( - expression: ExpOrStr, - into: ExpOrStr, - columns: t.Optional[t.Sequence[str | Identifier]] = None, - overwrite: t.Optional[bool] = None, - returning: t.Optional[ExpOrStr] = None, - dialect: DialectType = None, - copy: bool = True, - **opts, -) -> Insert: - """ - Builds an INSERT statement. - - Example: - >>> insert("VALUES (1, 2, 3)", "tbl").sql() - 'INSERT INTO tbl VALUES (1, 2, 3)' - - Args: - expression: the sql string or expression of the INSERT statement - into: the tbl to insert data to. - columns: optionally the table's column names. - overwrite: whether to INSERT OVERWRITE or not. - returning: sql conditional parsed into a RETURNING statement - dialect: the dialect used to parse the input expressions. - copy: whether to copy the expression. - **opts: other options to use to parse the input expressions. - - Returns: - Insert: the syntax tree for the INSERT statement. - """ - expr = maybe_parse(expression, dialect=dialect, copy=copy, **opts) - this: Table | Schema = maybe_parse(into, into=Table, dialect=dialect, copy=copy, **opts) - - if columns: - this = Schema(this=this, expressions=[to_identifier(c, copy=copy) for c in columns]) - - insert = Insert(this=this, expression=expr, overwrite=overwrite) - - if returning: - insert = insert.returning(returning, dialect=dialect, copy=False, **opts) - - return insert - - -def merge( - *when_exprs: ExpOrStr, - into: ExpOrStr, - using: ExpOrStr, - on: ExpOrStr, - returning: t.Optional[ExpOrStr] = None, - dialect: DialectType = None, - copy: bool = True, - **opts, -) -> Merge: - """ - Builds a MERGE statement. - - Example: - >>> merge("WHEN MATCHED THEN UPDATE SET col1 = source_table.col1", - ... "WHEN NOT MATCHED THEN INSERT (col1) VALUES (source_table.col1)", - ... into="my_table", - ... using="source_table", - ... 
on="my_table.id = source_table.id").sql() - 'MERGE INTO my_table USING source_table ON my_table.id = source_table.id WHEN MATCHED THEN UPDATE SET col1 = source_table.col1 WHEN NOT MATCHED THEN INSERT (col1) VALUES (source_table.col1)' - - Args: - *when_exprs: The WHEN clauses specifying actions for matched and unmatched rows. - into: The target table to merge data into. - using: The source table to merge data from. - on: The join condition for the merge. - returning: The columns to return from the merge. - dialect: The dialect used to parse the input expressions. - copy: Whether to copy the expression. - **opts: Other options to use to parse the input expressions. - - Returns: - Merge: The syntax tree for the MERGE statement. - """ - expressions: t.List[Expression] = [] - for when_expr in when_exprs: - expression = maybe_parse(when_expr, dialect=dialect, copy=copy, into=Whens, **opts) - expressions.extend([expression] if isinstance(expression, When) else expression.expressions) - - merge = Merge( - this=maybe_parse(into, dialect=dialect, copy=copy, **opts), - using=maybe_parse(using, dialect=dialect, copy=copy, **opts), - on=maybe_parse(on, dialect=dialect, copy=copy, **opts), - whens=Whens(expressions=expressions), - ) - if returning: - merge = merge.returning(returning, dialect=dialect, copy=False, **opts) - - return merge - - -def condition( - expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts -) -> Condition: - """ - Initialize a logical condition expression. - - Example: - >>> condition("x=1").sql() - 'x = 1' - - This is helpful for composing larger logical syntax trees: - >>> where = condition("x=1") - >>> where = where.and_("y=1") - >>> Select().from_("tbl").select("*").where(where).sql() - 'SELECT * FROM tbl WHERE x = 1 AND y = 1' - - Args: - *expression: the SQL code string to parse. - If an Expression instance is passed, this is used as-is. - dialect: the dialect used to parse the input expression (in the case that the - input expression is a SQL string). - copy: Whether to copy `expression` (only applies to expressions). - **opts: other options to use to parse the input expressions (again, in the case - that the input expression is a SQL string). - - Returns: - The new Condition instance - """ - return maybe_parse( - expression, - into=Condition, - dialect=dialect, - copy=copy, - **opts, - ) - - -def and_( - *expressions: t.Optional[ExpOrStr], - dialect: DialectType = None, - copy: bool = True, - wrap: bool = True, - **opts, -) -> Condition: - """ - Combine multiple conditions with an AND logical operator. - - Example: - >>> and_("x=1", and_("y=1", "z=1")).sql() - 'x = 1 AND (y = 1 AND z = 1)' - - Args: - *expressions: the SQL code strings to parse. - If an Expression instance is passed, this is used as-is. - dialect: the dialect used to parse the input expression. - copy: whether to copy `expressions` (only applies to Expressions). - wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid - precedence issues, but can be turned off when the produced AST is too deep and - causes recursion-related issues. - **opts: other options to use to parse the input expressions. - - Returns: - The new condition - """ - return t.cast(Condition, _combine(expressions, And, dialect, copy=copy, wrap=wrap, **opts)) - - -def or_( - *expressions: t.Optional[ExpOrStr], - dialect: DialectType = None, - copy: bool = True, - wrap: bool = True, - **opts, -) -> Condition: - """ - Combine multiple conditions with an OR logical operator. 
- - Example: - >>> or_("x=1", or_("y=1", "z=1")).sql() - 'x = 1 OR (y = 1 OR z = 1)' - - Args: - *expressions: the SQL code strings to parse. - If an Expression instance is passed, this is used as-is. - dialect: the dialect used to parse the input expression. - copy: whether to copy `expressions` (only applies to Expressions). - wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid - precedence issues, but can be turned off when the produced AST is too deep and - causes recursion-related issues. - **opts: other options to use to parse the input expressions. - - Returns: - The new condition - """ - return t.cast(Condition, _combine(expressions, Or, dialect, copy=copy, wrap=wrap, **opts)) - - -def xor( - *expressions: t.Optional[ExpOrStr], - dialect: DialectType = None, - copy: bool = True, - wrap: bool = True, - **opts, -) -> Condition: - """ - Combine multiple conditions with an XOR logical operator. - - Example: - >>> xor("x=1", xor("y=1", "z=1")).sql() - 'x = 1 XOR (y = 1 XOR z = 1)' - - Args: - *expressions: the SQL code strings to parse. - If an Expression instance is passed, this is used as-is. - dialect: the dialect used to parse the input expression. - copy: whether to copy `expressions` (only applies to Expressions). - wrap: whether to wrap the operands in `Paren`s. This is true by default to avoid - precedence issues, but can be turned off when the produced AST is too deep and - causes recursion-related issues. - **opts: other options to use to parse the input expressions. - - Returns: - The new condition - """ - return t.cast(Condition, _combine(expressions, Xor, dialect, copy=copy, wrap=wrap, **opts)) - - -def not_(expression: ExpOrStr, dialect: DialectType = None, copy: bool = True, **opts) -> Not: - """ - Wrap a condition with a NOT operator. - - Example: - >>> not_("this_suit='black'").sql() - "NOT this_suit = 'black'" - - Args: - expression: the SQL code string to parse. - If an Expression instance is passed, this is used as-is. - dialect: the dialect used to parse the input expression. - copy: whether to copy the expression or not. - **opts: other options to use to parse the input expressions. - - Returns: - The new condition. - """ - this = condition( - expression, - dialect=dialect, - copy=copy, - **opts, - ) - return Not(this=_wrap(this, Connector)) - - -def paren(expression: ExpOrStr, copy: bool = True) -> Paren: - """ - Wrap an expression in parentheses. - - Example: - >>> paren("5 + 3").sql() - '(5 + 3)' - - Args: - expression: the SQL code string to parse. - If an Expression instance is passed, this is used as-is. - copy: whether to copy the expression or not. - - Returns: - The wrapped expression. - """ - return Paren(this=maybe_parse(expression, copy=copy)) - - -SAFE_IDENTIFIER_RE: t.Pattern[str] = re.compile(r"^[_a-zA-Z][\w]*$") - - -@t.overload -def to_identifier(name: None, quoted: t.Optional[bool] = None, copy: bool = True) -> None: ... - - -@t.overload -def to_identifier( - name: str | Identifier, quoted: t.Optional[bool] = None, copy: bool = True -) -> Identifier: ... - - -def to_identifier(name, quoted=None, copy=True): - """Builds an identifier. - - Args: - name: The name to turn into an identifier. - quoted: Whether to force quote the identifier. - copy: Whether to copy name if it's an Identifier. - - Returns: - The identifier ast node. 
- """ - - if name is None: - return None - - if isinstance(name, Identifier): - identifier = maybe_copy(name, copy) - elif isinstance(name, str): - identifier = Identifier( - this=name, - quoted=not SAFE_IDENTIFIER_RE.match(name) if quoted is None else quoted, - ) - else: - raise ValueError(f"Name needs to be a string or an Identifier, got: {name.__class__}") - return identifier - - -def parse_identifier(name: str | Identifier, dialect: DialectType = None) -> Identifier: - """ - Parses a given string into an identifier. - - Args: - name: The name to parse into an identifier. - dialect: The dialect to parse against. - - Returns: - The identifier ast node. - """ - try: - expression = maybe_parse(name, dialect=dialect, into=Identifier) - except (ParseError, TokenError): - expression = to_identifier(name) - - return expression - - -INTERVAL_STRING_RE = re.compile(r"\s*(-?[0-9]+(?:\.[0-9]+)?)\s*([a-zA-Z]+)\s*") - - -def to_interval(interval: str | Literal) -> Interval: - """Builds an interval expression from a string like '1 day' or '5 months'.""" - if isinstance(interval, Literal): - if not interval.is_string: - raise ValueError("Invalid interval string.") - - interval = interval.this - - interval = maybe_parse(f"INTERVAL {interval}") - assert isinstance(interval, Interval) - return interval - - -def to_table( - sql_path: str | Table, dialect: DialectType = None, copy: bool = True, **kwargs -) -> Table: - """ - Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional. - If a table is passed in then that table is returned. - - Args: - sql_path: a `[catalog].[schema].[table]` string. - dialect: the source dialect according to which the table name will be parsed. - copy: Whether to copy a table if it is passed in. - kwargs: the kwargs to instantiate the resulting `Table` expression with. - - Returns: - A table expression. - """ - if isinstance(sql_path, Table): - return maybe_copy(sql_path, copy=copy) - - try: - table = maybe_parse(sql_path, into=Table, dialect=dialect) - except ParseError: - catalog, db, this = split_num_words(sql_path, ".", 3) - - if not this: - raise - - table = table_(this, db=db, catalog=catalog) - - for k, v in kwargs.items(): - table.set(k, v) - - return table - - -def to_column( - sql_path: str | Column, - quoted: t.Optional[bool] = None, - dialect: DialectType = None, - copy: bool = True, - **kwargs, -) -> Column: - """ - Create a column from a `[table].[column]` sql path. Table is optional. - If a column is passed in then that column is returned. - - Args: - sql_path: a `[table].[column]` string. - quoted: Whether or not to force quote identifiers. - dialect: the source dialect according to which the column name will be parsed. - copy: Whether to copy a column if it is passed in. - kwargs: the kwargs to instantiate the resulting `Column` expression with. - - Returns: - A column expression. 
- """ - if isinstance(sql_path, Column): - return maybe_copy(sql_path, copy=copy) - - try: - col = maybe_parse(sql_path, into=Column, dialect=dialect) - except ParseError: - return column(*reversed(sql_path.split(".")), quoted=quoted, **kwargs) - - for k, v in kwargs.items(): - col.set(k, v) - - if quoted: - for i in col.find_all(Identifier): - i.set("quoted", True) - - return col - - -def alias_( - expression: ExpOrStr, - alias: t.Optional[str | Identifier], - table: bool | t.Sequence[str | Identifier] = False, - quoted: t.Optional[bool] = None, - dialect: DialectType = None, - copy: bool = True, - **opts, -): - """Create an Alias expression. - - Example: - >>> alias_('foo', 'bar').sql() - 'foo AS bar' - - >>> alias_('(select 1, 2)', 'bar', table=['a', 'b']).sql() - '(SELECT 1, 2) AS bar(a, b)' - - Args: - expression: the SQL code strings to parse. - If an Expression instance is passed, this is used as-is. - alias: the alias name to use. If the name has - special characters it is quoted. - table: Whether to create a table alias, can also be a list of columns. - quoted: whether to quote the alias - dialect: the dialect used to parse the input expression. - copy: Whether to copy the expression. - **opts: other options to use to parse the input expressions. - - Returns: - Alias: the aliased expression - """ - exp = maybe_parse(expression, dialect=dialect, copy=copy, **opts) - alias = to_identifier(alias, quoted=quoted) - - if table: - table_alias = TableAlias(this=alias) - exp.set("alias", table_alias) - - if not isinstance(table, bool): - for column in table: - table_alias.append("columns", to_identifier(column, quoted=quoted)) - - return exp - - # We don't set the "alias" arg for Window expressions, because that would add an IDENTIFIER node in - # the AST, representing a "named_window" [1] construct (eg. bigquery). What we want is an ALIAS node - # for the complete Window expression. - # - # [1]: https://cloud.google.com/bigquery/docs/reference/standard-sql/window-function-calls - - if "alias" in exp.arg_types and not isinstance(exp, Window): - exp.set("alias", alias) - return exp - return Alias(this=exp, alias=alias) - - -def subquery( - expression: ExpOrStr, - alias: t.Optional[Identifier | str] = None, - dialect: DialectType = None, - **opts, -) -> Select: - """ - Build a subquery expression that's selected from. - - Example: - >>> subquery('select x from tbl', 'bar').select('x').sql() - 'SELECT x FROM (SELECT x FROM tbl) AS bar' - - Args: - expression: the SQL code strings to parse. - If an Expression instance is passed, this is used as-is. - alias: the alias name to use. - dialect: the dialect used to parse the input expression. - **opts: other options to use to parse the input expressions. - - Returns: - A new Select instance with the subquery expression included. 
- """ - - expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias, **opts) - return Select().from_(expression, dialect=dialect, **opts) - - -@t.overload -def column( - col: str | Identifier, - table: t.Optional[str | Identifier] = None, - db: t.Optional[str | Identifier] = None, - catalog: t.Optional[str | Identifier] = None, - *, - fields: t.Collection[t.Union[str, Identifier]], - quoted: t.Optional[bool] = None, - copy: bool = True, -) -> Dot: - pass - - -@t.overload -def column( - col: str | Identifier | Star, - table: t.Optional[str | Identifier] = None, - db: t.Optional[str | Identifier] = None, - catalog: t.Optional[str | Identifier] = None, - *, - fields: Lit[None] = None, - quoted: t.Optional[bool] = None, - copy: bool = True, -) -> Column: - pass - - -def column( - col, - table=None, - db=None, - catalog=None, - *, - fields=None, - quoted=None, - copy=True, -): - """ - Build a Column. - - Args: - col: Column name. - table: Table name. - db: Database name. - catalog: Catalog name. - fields: Additional fields using dots. - quoted: Whether to force quotes on the column's identifiers. - copy: Whether to copy identifiers if passed in. - - Returns: - The new Column instance. - """ - if not isinstance(col, Star): - col = to_identifier(col, quoted=quoted, copy=copy) - - this = Column( - this=col, - table=to_identifier(table, quoted=quoted, copy=copy), - db=to_identifier(db, quoted=quoted, copy=copy), - catalog=to_identifier(catalog, quoted=quoted, copy=copy), - ) - - if fields: - this = Dot.build( - (this, *(to_identifier(field, quoted=quoted, copy=copy) for field in fields)) - ) - return this - - -def cast( - expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, dialect: DialectType = None, **opts -) -> Cast: - """Cast an expression to a data type. - - Example: - >>> cast('x + 1', 'int').sql() - 'CAST(x + 1 AS INT)' - - Args: - expression: The expression to cast. - to: The datatype to cast to. - copy: Whether to copy the supplied expressions. - dialect: The target dialect. This is used to prevent a re-cast in the following scenario: - - The expression to be cast is already a exp.Cast expression - - The existing cast is to a type that is logically equivalent to new type - - For example, if :expression='CAST(x as DATETIME)' and :to=Type.TIMESTAMP, - but in the target dialect DATETIME is mapped to TIMESTAMP, then we will NOT return `CAST(x (as DATETIME) as TIMESTAMP)` - and instead just return the original expression `CAST(x as DATETIME)`. - - This is to prevent it being output as a double cast `CAST(x (as TIMESTAMP) as TIMESTAMP)` once the DATETIME -> TIMESTAMP - mapping is applied in the target dialect generator. - - Returns: - The new Cast instance. 
- """ - expr = maybe_parse(expression, copy=copy, dialect=dialect, **opts) - data_type = DataType.build(to, copy=copy, dialect=dialect, **opts) - - # dont re-cast if the expression is already a cast to the correct type - if isinstance(expr, Cast): - from sqlglot.dialects.dialect import Dialect - - target_dialect = Dialect.get_or_raise(dialect) - type_mapping = target_dialect.generator_class.TYPE_MAPPING - - existing_cast_type: DataType.Type = expr.to.this - new_cast_type: DataType.Type = data_type.this - types_are_equivalent = type_mapping.get( - existing_cast_type, existing_cast_type.value - ) == type_mapping.get(new_cast_type, new_cast_type.value) - - if expr.is_type(data_type) or types_are_equivalent: - return expr - - expr = Cast(this=expr, to=data_type) - expr.type = data_type - - return expr - - -def table_( - table: Identifier | str, - db: t.Optional[Identifier | str] = None, - catalog: t.Optional[Identifier | str] = None, - quoted: t.Optional[bool] = None, - alias: t.Optional[Identifier | str] = None, -) -> Table: - """Build a Table. - - Args: - table: Table name. - db: Database name. - catalog: Catalog name. - quote: Whether to force quotes on the table's identifiers. - alias: Table's alias. - - Returns: - The new Table instance. - """ - return Table( - this=to_identifier(table, quoted=quoted) if table else None, - db=to_identifier(db, quoted=quoted) if db else None, - catalog=to_identifier(catalog, quoted=quoted) if catalog else None, - alias=TableAlias(this=to_identifier(alias)) if alias else None, - ) - - -def values( - values: t.Iterable[t.Tuple[t.Any, ...]], - alias: t.Optional[str] = None, - columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None, -) -> Values: - """Build VALUES statement. - - Example: - >>> values([(1, '2')]).sql() - "VALUES (1, '2')" - - Args: - values: values statements that will be converted to SQL - alias: optional alias - columns: Optional list of ordered column names or ordered dictionary of column names to types. - If either are provided then an alias is also required. - - Returns: - Values: the Values expression object - """ - if columns and not alias: - raise ValueError("Alias is required when providing columns") - - return Values( - expressions=[convert(tup) for tup in values], - alias=( - TableAlias(this=to_identifier(alias), columns=[to_identifier(x) for x in columns]) - if columns - else (TableAlias(this=to_identifier(alias)) if alias else None) - ), - ) - - -def var(name: t.Optional[ExpOrStr]) -> Var: - """Build a SQL variable. - - Example: - >>> repr(var('x')) - 'Var(this=x)' - - >>> repr(var(column('x', table='y'))) - 'Var(this=x)' - - Args: - name: The name of the var or an expression who's name will become the var. - - Returns: - The new variable node. - """ - if not name: - raise ValueError("Cannot convert empty name into var.") - - if isinstance(name, Expression): - name = name.name - return Var(this=name) - - -def rename_table( - old_name: str | Table, - new_name: str | Table, - dialect: DialectType = None, -) -> Alter: - """Build ALTER TABLE... RENAME... expression - - Args: - old_name: The old name of the table - new_name: The new name of the table - dialect: The dialect to parse the table. 
- - Returns: - Alter table expression - """ - old_table = to_table(old_name, dialect=dialect) - new_table = to_table(new_name, dialect=dialect) - return Alter( - this=old_table, - kind="TABLE", - actions=[ - AlterRename(this=new_table), - ], - ) - - -def rename_column( - table_name: str | Table, - old_column_name: str | Column, - new_column_name: str | Column, - exists: t.Optional[bool] = None, - dialect: DialectType = None, -) -> Alter: - """Build ALTER TABLE... RENAME COLUMN... expression - - Args: - table_name: Name of the table - old_column: The old name of the column - new_column: The new name of the column - exists: Whether to add the `IF EXISTS` clause - dialect: The dialect to parse the table/column. - - Returns: - Alter table expression - """ - table = to_table(table_name, dialect=dialect) - old_column = to_column(old_column_name, dialect=dialect) - new_column = to_column(new_column_name, dialect=dialect) - return Alter( - this=table, - kind="TABLE", - actions=[ - RenameColumn(this=old_column, to=new_column, exists=exists), - ], - ) - - -def convert(value: t.Any, copy: bool = False) -> Expression: - """Convert a python value into an expression object. - - Raises an error if a conversion is not possible. - - Args: - value: A python object. - copy: Whether to copy `value` (only applies to Expressions and collections). - - Returns: - The equivalent expression object. - """ - if isinstance(value, Expression): - return maybe_copy(value, copy) - if isinstance(value, str): - return Literal.string(value) - if isinstance(value, bool): - return Boolean(this=value) - if value is None or (isinstance(value, float) and math.isnan(value)): - return null() - if isinstance(value, numbers.Number): - return Literal.number(value) - if isinstance(value, bytes): - return HexString(this=value.hex()) - if isinstance(value, datetime.datetime): - datetime_literal = Literal.string(value.isoformat(sep=" ")) - - tz = None - if value.tzinfo: - # this works for zoneinfo.ZoneInfo, pytz.timezone and datetime.datetime.utc to return IANA timezone names like "America/Los_Angeles" - # instead of abbreviations like "PDT". This is for consistency with other timezone handling functions in SQLGlot - tz = Literal.string(str(value.tzinfo)) - - return TimeStrToTime(this=datetime_literal, zone=tz) - if isinstance(value, datetime.date): - date_literal = Literal.string(value.strftime("%Y-%m-%d")) - return DateStrToDate(this=date_literal) - if isinstance(value, tuple): - if hasattr(value, "_fields"): - return Struct( - expressions=[ - PropertyEQ( - this=to_identifier(k), expression=convert(getattr(value, k), copy=copy) - ) - for k in value._fields - ] - ) - return Tuple(expressions=[convert(v, copy=copy) for v in value]) - if isinstance(value, list): - return Array(expressions=[convert(v, copy=copy) for v in value]) - if isinstance(value, dict): - return Map( - keys=Array(expressions=[convert(k, copy=copy) for k in value]), - values=Array(expressions=[convert(v, copy=copy) for v in value.values()]), - ) - if hasattr(value, "__dict__"): - return Struct( - expressions=[ - PropertyEQ(this=to_identifier(k), expression=convert(v, copy=copy)) - for k, v in value.__dict__.items() - ] - ) - raise ValueError(f"Cannot convert {value}") - - -def replace_children(expression: Expression, fun: t.Callable, *args, **kwargs) -> None: - """ - Replace children of an expression with the result of a lambda fun(child) -> exp. 
- """ - for k, v in tuple(expression.args.items()): - is_list_arg = type(v) is list - - child_nodes = v if is_list_arg else [v] - new_child_nodes = [] - - for cn in child_nodes: - if isinstance(cn, Expression): - for child_node in ensure_collection(fun(cn, *args, **kwargs)): - new_child_nodes.append(child_node) - else: - new_child_nodes.append(cn) - - expression.set(k, new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)) - - -def replace_tree( - expression: Expression, - fun: t.Callable, - prune: t.Optional[t.Callable[[Expression], bool]] = None, -) -> Expression: - """ - Replace an entire tree with the result of function calls on each node. - - This will be traversed in reverse dfs, so leaves first. - If new nodes are created as a result of function calls, they will also be traversed. - """ - stack = list(expression.dfs(prune=prune)) - - while stack: - node = stack.pop() - new_node = fun(node) - - if new_node is not node: - node.replace(new_node) - - if isinstance(new_node, Expression): - stack.append(new_node) - - return new_node - - -def column_table_names(expression: Expression, exclude: str = "") -> t.Set[str]: - """ - Return all table names referenced through columns in an expression. - - Example: - >>> import sqlglot - >>> sorted(column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))) - ['a', 'c'] - - Args: - expression: expression to find table names. - exclude: a table name to exclude - - Returns: - A list of unique names. - """ - return { - table - for table in (column.table for column in expression.find_all(Column)) - if table and table != exclude - } - - -def table_name(table: Table | str, dialect: DialectType = None, identify: bool = False) -> str: - """Get the full name of a table as a string. - - Args: - table: Table expression node or string. - dialect: The dialect to generate the table name for. - identify: Determines when an identifier should be quoted. Possible values are: - False (default): Never quote, except in cases where it's mandatory by the dialect. - True: Always quote. - - Examples: - >>> from sqlglot import exp, parse_one - >>> table_name(parse_one("select * from a.b.c").find(exp.Table)) - 'a.b.c' - - Returns: - The table name. - """ - - table = maybe_parse(table, into=Table, dialect=dialect) - - if not table: - raise ValueError(f"Cannot parse {table}") - - return ".".join( - ( - part.sql(dialect=dialect, identify=True, copy=False, comments=False) - if identify or not SAFE_IDENTIFIER_RE.match(part.name) - else part.name - ) - for part in table.parts - ) - - -def normalize_table_name(table: str | Table, dialect: DialectType = None, copy: bool = True) -> str: - """Returns a case normalized table name without quotes. - - Args: - table: the table to normalize - dialect: the dialect to use for normalization rules - copy: whether to copy the expression. - - Examples: - >>> normalize_table_name("`A-B`.c", dialect="bigquery") - 'A-B.c' - """ - from sqlglot.optimizer.normalize_identifiers import normalize_identifiers - - return ".".join( - p.name - for p in normalize_identifiers( - to_table(table, dialect=dialect, copy=copy), dialect=dialect - ).parts - ) - - -def replace_tables( - expression: E, mapping: t.Dict[str, str], dialect: DialectType = None, copy: bool = True -) -> E: - """Replace all tables in expression according to the mapping. - - Args: - expression: expression node to be transformed and replaced. - mapping: mapping of table names. - dialect: the dialect of the mapping table - copy: whether to copy the expression. 
- - Examples: - >>> from sqlglot import exp, parse_one - >>> replace_tables(parse_one("select * from a.b"), {"a.b": "c"}).sql() - 'SELECT * FROM c /* a.b */' - - Returns: - The mapped expression. - """ - - mapping = {normalize_table_name(k, dialect=dialect): v for k, v in mapping.items()} - - def _replace_tables(node: Expression) -> Expression: - if isinstance(node, Table) and node.meta.get("replace") is not False: - original = normalize_table_name(node, dialect=dialect) - new_name = mapping.get(original) - - if new_name: - table = to_table( - new_name, - **{k: v for k, v in node.args.items() if k not in TABLE_PARTS}, - dialect=dialect, - ) - table.add_comments([original]) - return table - return node - - return expression.transform(_replace_tables, copy=copy) # type: ignore - - -def replace_placeholders(expression: Expression, *args, **kwargs) -> Expression: - """Replace placeholders in an expression. - - Args: - expression: expression node to be transformed and replaced. - args: positional names that will substitute unnamed placeholders in the given order. - kwargs: keyword arguments that will substitute named placeholders. - - Examples: - >>> from sqlglot import exp, parse_one - >>> replace_placeholders( - ... parse_one("select * from :tbl where ? = ?"), - ... exp.to_identifier("str_col"), "b", tbl=exp.to_identifier("foo") - ... ).sql() - "SELECT * FROM foo WHERE str_col = 'b'" - - Returns: - The mapped expression. - """ - - def _replace_placeholders(node: Expression, args, **kwargs) -> Expression: - if isinstance(node, Placeholder): - if node.this: - new_name = kwargs.get(node.this) - if new_name is not None: - return convert(new_name) - else: - try: - return convert(next(args)) - except StopIteration: - pass - return node - - return expression.transform(_replace_placeholders, iter(args), **kwargs) - - -def expand( - expression: Expression, - sources: t.Dict[str, Query | t.Callable[[], Query]], - dialect: DialectType = None, - copy: bool = True, -) -> Expression: - """Transforms an expression by expanding all referenced sources into subqueries. - - Examples: - >>> from sqlglot import parse_one - >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y")}).sql() - 'SELECT * FROM (SELECT * FROM y) AS z /* source: x */' - - >>> expand(parse_one("select * from x AS z"), {"x": parse_one("select * from y"), "y": parse_one("select * from z")}).sql() - 'SELECT * FROM (SELECT * FROM (SELECT * FROM z) AS y /* source: y */) AS z /* source: x */' - - Args: - expression: The expression to expand. - sources: A dict of name to query or a callable that provides a query on demand. - dialect: The dialect of the sources dict or the callable. - copy: Whether to copy the expression during transformation. Defaults to True. - - Returns: - The transformed expression. 
- """ - normalized_sources = {normalize_table_name(k, dialect=dialect): v for k, v in sources.items()} - - def _expand(node: Expression): - if isinstance(node, Table): - name = normalize_table_name(node, dialect=dialect) - source = normalized_sources.get(name) - - if source: - # Create a subquery with the same alias (or table name if no alias) - parsed_source = source() if callable(source) else source - subquery = parsed_source.subquery(node.alias or name) - subquery.comments = [f"source: {name}"] - - # Continue expanding within the subquery - return subquery.transform(_expand, copy=False) - - return node - - return expression.transform(_expand, copy=copy) - - -def func(name: str, *args, copy: bool = True, dialect: DialectType = None, **kwargs) -> Func: - """ - Returns a Func expression. - - Examples: - >>> func("abs", 5).sql() - 'ABS(5)' - - >>> func("cast", this=5, to=DataType.build("DOUBLE")).sql() - 'CAST(5 AS DOUBLE)' - - Args: - name: the name of the function to build. - args: the args used to instantiate the function of interest. - copy: whether to copy the argument expressions. - dialect: the source dialect. - kwargs: the kwargs used to instantiate the function of interest. - - Note: - The arguments `args` and `kwargs` are mutually exclusive. - - Returns: - An instance of the function of interest, or an anonymous function, if `name` doesn't - correspond to an existing `sqlglot.expressions.Func` class. - """ - if args and kwargs: - raise ValueError("Can't use both args and kwargs to instantiate a function.") - - from sqlglot.dialects.dialect import Dialect - - dialect = Dialect.get_or_raise(dialect) - - converted: t.List[Expression] = [maybe_parse(arg, dialect=dialect, copy=copy) for arg in args] - kwargs = {key: maybe_parse(value, dialect=dialect, copy=copy) for key, value in kwargs.items()} - - constructor = dialect.parser_class.FUNCTIONS.get(name.upper()) - if constructor: - if converted: - if "dialect" in constructor.__code__.co_varnames: - function = constructor(converted, dialect=dialect) - else: - function = constructor(converted) - elif constructor.__name__ == "from_arg_list": - function = constructor.__self__(**kwargs) # type: ignore - else: - constructor = FUNCTION_BY_NAME.get(name.upper()) - if constructor: - function = constructor(**kwargs) - else: - raise ValueError( - f"Unable to convert '{name}' into a Func. Either manually construct " - "the Func expression of interest or parse the function call." - ) - else: - kwargs = kwargs or {"expressions": converted} - function = Anonymous(this=name, **kwargs) - - for error_message in function.error_messages(converted): - raise ValueError(error_message) - - return function - - -def case( - expression: t.Optional[ExpOrStr] = None, - **opts, -) -> Case: - """ - Initialize a CASE statement. - - Example: - case().when("a = 1", "foo").else_("bar") - - Args: - expression: Optionally, the input expression (not all dialects support this) - **opts: Extra keyword arguments for parsing `expression` - """ - if expression is not None: - this = maybe_parse(expression, **opts) - else: - this = None - return Case(this=this, ifs=[]) - - -def array( - *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs -) -> Array: - """ - Returns an array. - - Examples: - >>> array(1, 'x').sql() - 'ARRAY(1, x)' - - Args: - expressions: the expressions to add to the array. - copy: whether to copy the argument expressions. - dialect: the source dialect. - kwargs: the kwargs used to instantiate the function of interest. 
- - Returns: - An array expression. - """ - return Array( - expressions=[ - maybe_parse(expression, copy=copy, dialect=dialect, **kwargs) - for expression in expressions - ] - ) - - -def tuple_( - *expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs -) -> Tuple: - """ - Returns an tuple. - - Examples: - >>> tuple_(1, 'x').sql() - '(1, x)' - - Args: - expressions: the expressions to add to the tuple. - copy: whether to copy the argument expressions. - dialect: the source dialect. - kwargs: the kwargs used to instantiate the function of interest. - - Returns: - A tuple expression. - """ - return Tuple( - expressions=[ - maybe_parse(expression, copy=copy, dialect=dialect, **kwargs) - for expression in expressions - ] - ) - - -def true() -> Boolean: - """ - Returns a true Boolean expression. - """ - return Boolean(this=True) - - -def false() -> Boolean: - """ - Returns a false Boolean expression. - """ - return Boolean(this=False) - - -def null() -> Null: - """ - Returns a Null expression. - """ - return Null() - - -NONNULL_CONSTANTS = ( - Literal, - Boolean, -) - -CONSTANTS = ( - Literal, - Boolean, - Null, -) diff --git a/altimate_packages/sqlglot/generator.py b/altimate_packages/sqlglot/generator.py deleted file mode 100644 index 458256eb5..000000000 --- a/altimate_packages/sqlglot/generator.py +++ /dev/null @@ -1,4993 +0,0 @@ -from __future__ import annotations - -import logging -import re -import typing as t -from collections import defaultdict -from functools import reduce, wraps - -from sqlglot import exp -from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages -from sqlglot.helper import apply_index_offset, csv, name_sequence, seq_get -from sqlglot.jsonpath import ALL_JSON_PATH_PARTS, JSON_PATH_PART_TRANSFORMS -from sqlglot.time import format_time -from sqlglot.tokens import TokenType - -if t.TYPE_CHECKING: - from sqlglot._typing import E - from sqlglot.dialects.dialect import DialectType - - G = t.TypeVar("G", bound="Generator") - GeneratorMethod = t.Callable[[G, E], str] - -logger = logging.getLogger("sqlglot") - -ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)") -UNSUPPORTED_TEMPLATE = "Argument '{}' is not supported for expression '{}' when targeting {}." - - -def unsupported_args( - *args: t.Union[str, t.Tuple[str, str]], -) -> t.Callable[[GeneratorMethod], GeneratorMethod]: - """ - Decorator that can be used to mark certain args of an `Expression` subclass as unsupported. - It expects a sequence of argument names or pairs of the form (argument_name, diagnostic_msg). 
- """ - diagnostic_by_arg: t.Dict[str, t.Optional[str]] = {} - for arg in args: - if isinstance(arg, str): - diagnostic_by_arg[arg] = None - else: - diagnostic_by_arg[arg[0]] = arg[1] - - def decorator(func: GeneratorMethod) -> GeneratorMethod: - @wraps(func) - def _func(generator: G, expression: E) -> str: - expression_name = expression.__class__.__name__ - dialect_name = generator.dialect.__class__.__name__ - - for arg_name, diagnostic in diagnostic_by_arg.items(): - if expression.args.get(arg_name): - diagnostic = diagnostic or UNSUPPORTED_TEMPLATE.format( - arg_name, expression_name, dialect_name - ) - generator.unsupported(diagnostic) - - return func(generator, expression) - - return _func - - return decorator - - -class _Generator(type): - def __new__(cls, clsname, bases, attrs): - klass = super().__new__(cls, clsname, bases, attrs) - - # Remove transforms that correspond to unsupported JSONPathPart expressions - for part in ALL_JSON_PATH_PARTS - klass.SUPPORTED_JSON_PATH_PARTS: - klass.TRANSFORMS.pop(part, None) - - return klass - - -class Generator(metaclass=_Generator): - """ - Generator converts a given syntax tree to the corresponding SQL string. - - Args: - pretty: Whether to format the produced SQL string. - Default: False. - identify: Determines when an identifier should be quoted. Possible values are: - False (default): Never quote, except in cases where it's mandatory by the dialect. - True or 'always': Always quote. - 'safe': Only quote identifiers that are case insensitive. - normalize: Whether to normalize identifiers to lowercase. - Default: False. - pad: The pad size in a formatted string. For example, this affects the indentation of - a projection in a query, relative to its nesting level. - Default: 2. - indent: The indentation size in a formatted string. For example, this affects the - indentation of subqueries and filters under a `WHERE` clause. - Default: 2. - normalize_functions: How to normalize function names. Possible values are: - "upper" or True (default): Convert names to uppercase. - "lower": Convert names to lowercase. - False: Disables function name normalization. - unsupported_level: Determines the generator's behavior when it encounters unsupported expressions. - Default ErrorLevel.WARN. - max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError. - This is only relevant if unsupported_level is ErrorLevel.RAISE. - Default: 3 - leading_comma: Whether the comma is leading or trailing in select expressions. - This is only relevant when generating in pretty mode. - Default: False - max_text_width: The max number of characters in a segment before creating new lines in pretty mode. - The default is on the smaller end because the length only represents a segment and not the true - line length. - Default: 80 - comments: Whether to preserve comments in the output SQL code. 
- Default: True - """ - - TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { - **JSON_PATH_PART_TRANSFORMS, - exp.AllowedValuesProperty: lambda self, - e: f"ALLOWED_VALUES {self.expressions(e, flat=True)}", - exp.AnalyzeColumns: lambda self, e: self.sql(e, "this"), - exp.AnalyzeWith: lambda self, e: self.expressions(e, prefix="WITH ", sep=" "), - exp.ArrayContainsAll: lambda self, e: self.binary(e, "@>"), - exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"), - exp.AutoRefreshProperty: lambda self, e: f"AUTO REFRESH {self.sql(e, 'this')}", - exp.BackupProperty: lambda self, e: f"BACKUP {self.sql(e, 'this')}", - exp.CaseSpecificColumnConstraint: lambda _, - e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC", - exp.Ceil: lambda self, e: self.ceil_floor(e), - exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}", - exp.CharacterSetProperty: lambda self, - e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", - exp.ClusteredColumnConstraint: lambda self, - e: f"CLUSTERED ({self.expressions(e, 'this', indent=False)})", - exp.CollateColumnConstraint: lambda self, e: f"COLLATE {self.sql(e, 'this')}", - exp.CommentColumnConstraint: lambda self, e: f"COMMENT {self.sql(e, 'this')}", - exp.ConnectByRoot: lambda self, e: f"CONNECT_BY_ROOT {self.sql(e, 'this')}", - exp.ConvertToCharset: lambda self, e: self.func( - "CONVERT", e.this, e.args["dest"], e.args.get("source") - ), - exp.CopyGrantsProperty: lambda *_: "COPY GRANTS", - exp.CredentialsProperty: lambda self, - e: f"CREDENTIALS=({self.expressions(e, 'expressions', sep=' ')})", - exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}", - exp.DefaultColumnConstraint: lambda self, e: f"DEFAULT {self.sql(e, 'this')}", - exp.DynamicProperty: lambda *_: "DYNAMIC", - exp.EmptyProperty: lambda *_: "EMPTY", - exp.EncodeColumnConstraint: lambda self, e: f"ENCODE {self.sql(e, 'this')}", - exp.EnviromentProperty: lambda self, e: f"ENVIRONMENT ({self.expressions(e, flat=True)})", - exp.EphemeralColumnConstraint: lambda self, - e: f"EPHEMERAL{(' ' + self.sql(e, 'this')) if e.this else ''}", - exp.ExcludeColumnConstraint: lambda self, e: f"EXCLUDE {self.sql(e, 'this').lstrip()}", - exp.ExecuteAsProperty: lambda self, e: self.naked_property(e), - exp.Except: lambda self, e: self.set_operations(e), - exp.ExternalProperty: lambda *_: "EXTERNAL", - exp.Floor: lambda self, e: self.ceil_floor(e), - exp.Get: lambda self, e: self.get_put_sql(e), - exp.GlobalProperty: lambda *_: "GLOBAL", - exp.HeapProperty: lambda *_: "HEAP", - exp.IcebergProperty: lambda *_: "ICEBERG", - exp.InheritsProperty: lambda self, e: f"INHERITS ({self.expressions(e, flat=True)})", - exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", - exp.InputModelProperty: lambda self, e: f"INPUT{self.sql(e, 'this')}", - exp.Intersect: lambda self, e: self.set_operations(e), - exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}", - exp.Int64: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.BIGINT)), - exp.LanguageProperty: lambda self, e: self.naked_property(e), - exp.LocationProperty: lambda self, e: self.naked_property(e), - exp.LogProperty: lambda _, e: f"{'NO ' if e.args.get('no') else ''}LOG", - exp.MaterializedProperty: lambda *_: "MATERIALIZED", - exp.NonClusteredColumnConstraint: lambda self, - e: f"NONCLUSTERED ({self.expressions(e, 'this', indent=False)})", - exp.NoPrimaryIndexProperty: lambda 
*_: "NO PRIMARY INDEX", - exp.NotForReplicationColumnConstraint: lambda *_: "NOT FOR REPLICATION", - exp.OnCommitProperty: lambda _, - e: f"ON COMMIT {'DELETE' if e.args.get('delete') else 'PRESERVE'} ROWS", - exp.OnProperty: lambda self, e: f"ON {self.sql(e, 'this')}", - exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}", - exp.Operator: lambda self, e: self.binary(e, ""), # The operator is produced in `binary` - exp.OutputModelProperty: lambda self, e: f"OUTPUT{self.sql(e, 'this')}", - exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}", - exp.PartitionedByBucket: lambda self, e: self.func("BUCKET", e.this, e.expression), - exp.PartitionByTruncate: lambda self, e: self.func("TRUNCATE", e.this, e.expression), - exp.PivotAny: lambda self, e: f"ANY{self.sql(e, 'this')}", - exp.ProjectionPolicyColumnConstraint: lambda self, - e: f"PROJECTION POLICY {self.sql(e, 'this')}", - exp.Put: lambda self, e: self.get_put_sql(e), - exp.RemoteWithConnectionModelProperty: lambda self, - e: f"REMOTE WITH CONNECTION {self.sql(e, 'this')}", - exp.ReturnsProperty: lambda self, e: ( - "RETURNS NULL ON NULL INPUT" if e.args.get("null") else self.naked_property(e) - ), - exp.SampleProperty: lambda self, e: f"SAMPLE BY {self.sql(e, 'this')}", - exp.SecureProperty: lambda *_: "SECURE", - exp.SecurityProperty: lambda self, e: f"SECURITY {self.sql(e, 'this')}", - exp.SetConfigProperty: lambda self, e: self.sql(e, "this"), - exp.SetProperty: lambda _, e: f"{'MULTI' if e.args.get('multi') else ''}SET", - exp.SettingsProperty: lambda self, e: f"SETTINGS{self.seg('')}{(self.expressions(e))}", - exp.SharingProperty: lambda self, e: f"SHARING={self.sql(e, 'this')}", - exp.SqlReadWriteProperty: lambda _, e: e.name, - exp.SqlSecurityProperty: lambda _, - e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}", - exp.StabilityProperty: lambda _, e: e.name, - exp.Stream: lambda self, e: f"STREAM {self.sql(e, 'this')}", - exp.StreamingTableProperty: lambda *_: "STREAMING", - exp.StrictProperty: lambda *_: "STRICT", - exp.SwapTable: lambda self, e: f"SWAP WITH {self.sql(e, 'this')}", - exp.Tags: lambda self, e: f"TAG ({self.expressions(e, flat=True)})", - exp.TemporaryProperty: lambda *_: "TEMPORARY", - exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}", - exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}", - exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}", - exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions), - exp.TransientProperty: lambda *_: "TRANSIENT", - exp.Union: lambda self, e: self.set_operations(e), - exp.UnloggedProperty: lambda *_: "UNLOGGED", - exp.UsingTemplateProperty: lambda self, e: f"USING TEMPLATE {self.sql(e, 'this')}", - exp.UsingData: lambda self, e: f"USING DATA {self.sql(e, 'this')}", - exp.Uuid: lambda *_: "UUID()", - exp.UppercaseColumnConstraint: lambda *_: "UPPERCASE", - exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), - exp.ViewAttributeProperty: lambda self, e: f"WITH {self.sql(e, 'this')}", - exp.VolatileProperty: lambda *_: "VOLATILE", - exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}", - exp.WithProcedureOptions: lambda self, e: f"WITH {self.expressions(e, flat=True)}", - exp.WithSchemaBindingProperty: lambda self, e: f"WITH SCHEMA {self.sql(e, 'this')}", - exp.WithOperator: lambda self, e: f"{self.sql(e, 'this')} WITH {self.sql(e, 'op')}", - exp.ForceProperty: lambda *_: "FORCE", - } - 
- # Whether null ordering is supported in order by - # True: Full Support, None: No support, False: No support for certain cases - # such as window specifications, aggregate functions etc - NULL_ORDERING_SUPPORTED: t.Optional[bool] = True - - # Whether ignore nulls is inside the agg or outside. - # FIRST(x IGNORE NULLS) OVER vs FIRST (x) IGNORE NULLS OVER - IGNORE_NULLS_IN_FUNC = False - - # Whether locking reads (i.e. SELECT ... FOR UPDATE/SHARE) are supported - LOCKING_READS_SUPPORTED = False - - # Whether the EXCEPT and INTERSECT operations can return duplicates - EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE = True - - # Wrap derived values in parens, usually standard but spark doesn't support it - WRAP_DERIVED_VALUES = True - - # Whether create function uses an AS before the RETURN - CREATE_FUNCTION_RETURN_AS = True - - # Whether MERGE ... WHEN MATCHED BY SOURCE is allowed - MATCHED_BY_SOURCE = True - - # Whether the INTERVAL expression works only with values like '1 day' - SINGLE_STRING_INTERVAL = False - - # Whether the plural form of date parts like day (i.e. "days") is supported in INTERVALs - INTERVAL_ALLOWS_PLURAL_FORM = True - - # Whether limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") - LIMIT_FETCH = "ALL" - - # Whether limit and fetch allows expresions or just limits - LIMIT_ONLY_LITERALS = False - - # Whether a table is allowed to be renamed with a db - RENAME_TABLE_WITH_DB = True - - # The separator for grouping sets and rollups - GROUPINGS_SEP = "," - - # The string used for creating an index on a table - INDEX_ON = "ON" - - # Whether join hints should be generated - JOIN_HINTS = True - - # Whether table hints should be generated - TABLE_HINTS = True - - # Whether query hints should be generated - QUERY_HINTS = True - - # What kind of separator to use for query hints - QUERY_HINT_SEP = ", " - - # Whether comparing against booleans (e.g. x IS TRUE) is supported - IS_BOOL_ALLOWED = True - - # Whether to include the "SET" keyword in the "INSERT ... ON DUPLICATE KEY UPDATE" statement - DUPLICATE_KEY_UPDATE_WITH_SET = True - - # Whether to generate the limit as TOP instead of LIMIT - LIMIT_IS_TOP = False - - # Whether to generate INSERT INTO ... RETURNING or INSERT INTO RETURNING ... - RETURNING_END = True - - # Whether to generate an unquoted value for EXTRACT's date part argument - EXTRACT_ALLOWS_QUOTES = True - - # Whether TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax - TZ_TO_WITH_TIME_ZONE = False - - # Whether the NVL2 function is supported - NVL2_SUPPORTED = True - - # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax - SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") - - # Whether VALUES statements can be used as derived tables. - # MySQL 5 and Redshift do not allow this, so when False, it will convert - # SELECT * VALUES into SELECT UNION - VALUES_AS_TABLE = True - - # Whether the word COLUMN is included when adding a column with ALTER TABLE - ALTER_TABLE_INCLUDE_COLUMN_KEYWORD = True - - # UNNEST WITH ORDINALITY (presto) instead of UNNEST WITH OFFSET (bigquery) - UNNEST_WITH_ORDINALITY = True - - # Whether FILTER (WHERE cond) can be used for conditional aggregation - AGGREGATE_FILTER_SUPPORTED = True - - # Whether JOIN sides (LEFT, RIGHT) are supported in conjunction with SEMI/ANTI join kinds - SEMI_ANTI_JOIN_WITH_SIDE = True - - # Whether to include the type of a computed column in the CREATE DDL - COMPUTED_COLUMN_WITH_TYPE = True - - # Whether CREATE TABLE .. COPY .. is supported. 
False means we'll generate CLONE instead of COPY - SUPPORTS_TABLE_COPY = True - - # Whether parentheses are required around the table sample's expression - TABLESAMPLE_REQUIRES_PARENS = True - - # Whether a table sample clause's size needs to be followed by the ROWS keyword - TABLESAMPLE_SIZE_IS_ROWS = True - - # The keyword(s) to use when generating a sample clause - TABLESAMPLE_KEYWORDS = "TABLESAMPLE" - - # Whether the TABLESAMPLE clause supports a method name, like BERNOULLI - TABLESAMPLE_WITH_METHOD = True - - # The keyword to use when specifying the seed of a sample clause - TABLESAMPLE_SEED_KEYWORD = "SEED" - - # Whether COLLATE is a function instead of a binary operator - COLLATE_IS_FUNC = False - - # Whether data types support additional specifiers like e.g. CHAR or BYTE (oracle) - DATA_TYPE_SPECIFIERS_ALLOWED = False - - # Whether conditions require booleans WHERE x = 0 vs WHERE x - ENSURE_BOOLS = False - - # Whether the "RECURSIVE" keyword is required when defining recursive CTEs - CTE_RECURSIVE_KEYWORD_REQUIRED = True - - # Whether CONCAT requires >1 arguments - SUPPORTS_SINGLE_ARG_CONCAT = True - - # Whether LAST_DAY function supports a date part argument - LAST_DAY_SUPPORTS_DATE_PART = True - - # Whether named columns are allowed in table aliases - SUPPORTS_TABLE_ALIAS_COLUMNS = True - - # Whether UNPIVOT aliases are Identifiers (False means they're Literals) - UNPIVOT_ALIASES_ARE_IDENTIFIERS = True - - # What delimiter to use for separating JSON key/value pairs - JSON_KEY_VALUE_PAIR_SEP = ":" - - # INSERT OVERWRITE TABLE x override - INSERT_OVERWRITE = " OVERWRITE TABLE" - - # Whether the SELECT .. INTO syntax is used instead of CTAS - SUPPORTS_SELECT_INTO = False - - # Whether UNLOGGED tables can be created - SUPPORTS_UNLOGGED_TABLES = False - - # Whether the CREATE TABLE LIKE statement is supported - SUPPORTS_CREATE_TABLE_LIKE = True - - # Whether the LikeProperty needs to be specified inside of the schema clause - LIKE_PROPERTY_INSIDE_SCHEMA = False - - # Whether DISTINCT can be followed by multiple args in an AggFunc. If not, it will be - # transpiled into a series of CASE-WHEN-ELSE, ultimately using a tuple conseisting of the args - MULTI_ARG_DISTINCT = True - - # Whether the JSON extraction operators expect a value of type JSON - JSON_TYPE_REQUIRED_FOR_EXTRACTION = False - - # Whether bracketed keys like ["foo"] are supported in JSON paths - JSON_PATH_BRACKETED_KEY_SUPPORTED = True - - # Whether to escape keys using single quotes in JSON paths - JSON_PATH_SINGLE_QUOTE_ESCAPE = False - - # The JSONPathPart expressions supported by this dialect - SUPPORTED_JSON_PATH_PARTS = ALL_JSON_PATH_PARTS.copy() - - # Whether any(f(x) for x in array) can be implemented by this dialect - CAN_IMPLEMENT_ARRAY_ANY = False - - # Whether the function TO_NUMBER is supported - SUPPORTS_TO_NUMBER = True - - # Whether EXCLUDE in window specification is supported - SUPPORTS_WINDOW_EXCLUDE = False - - # Whether or not set op modifiers apply to the outer set op or select. - # SELECT * FROM x UNION SELECT * FROM y LIMIT 1 - # True means limit 1 happens after the set op, False means it it happens on y. 
- SET_OP_MODIFIERS = True - - # Whether parameters from COPY statement are wrapped in parentheses - COPY_PARAMS_ARE_WRAPPED = True - - # Whether values of params are set with "=" token or empty space - COPY_PARAMS_EQ_REQUIRED = False - - # Whether COPY statement has INTO keyword - COPY_HAS_INTO_KEYWORD = True - - # Whether the conditional TRY(expression) function is supported - TRY_SUPPORTED = True - - # Whether the UESCAPE syntax in unicode strings is supported - SUPPORTS_UESCAPE = True - - # The keyword to use when generating a star projection with excluded columns - STAR_EXCEPT = "EXCEPT" - - # The HEX function name - HEX_FUNC = "HEX" - - # The keywords to use when prefixing & separating WITH based properties - WITH_PROPERTIES_PREFIX = "WITH" - - # Whether to quote the generated expression of exp.JsonPath - QUOTE_JSON_PATH = True - - # Whether the text pattern/fill (3rd) parameter of RPAD()/LPAD() is optional (defaults to space) - PAD_FILL_PATTERN_IS_REQUIRED = False - - # Whether a projection can explode into multiple rows, e.g. by unnesting an array. - SUPPORTS_EXPLODING_PROJECTIONS = True - - # Whether ARRAY_CONCAT can be generated with varlen args or if it should be reduced to 2-arg version - ARRAY_CONCAT_IS_VAR_LEN = True - - # Whether CONVERT_TIMEZONE() is supported; if not, it will be generated as exp.AtTimeZone - SUPPORTS_CONVERT_TIMEZONE = False - - # Whether MEDIAN(expr) is supported; if not, it will be generated as PERCENTILE_CONT(expr, 0.5) - SUPPORTS_MEDIAN = True - - # Whether UNIX_SECONDS(timestamp) is supported - SUPPORTS_UNIX_SECONDS = False - - # Whether to wrap in `AlterSet`, e.g., ALTER ... SET () - ALTER_SET_WRAPPED = False - - # The name to generate for the JSONPath expression. If `None`, only `this` will be generated - PARSE_JSON_NAME: t.Optional[str] = "PARSE_JSON" - - # The function name of the exp.ArraySize expression - ARRAY_SIZE_NAME: str = "ARRAY_LENGTH" - - # The syntax to use when altering the type of a column - ALTER_SET_TYPE = "SET DATA TYPE" - - # Whether exp.ArraySize should generate the dimension arg too (valid for Postgres & DuckDB) - # None -> Doesn't support it at all - # False (DuckDB) -> Has backwards-compatible support, but preferably generated without - # True (Postgres) -> Explicitly requires it - ARRAY_SIZE_DIM_REQUIRED: t.Optional[bool] = None - - TYPE_MAPPING = { - exp.DataType.Type.DATETIME2: "TIMESTAMP", - exp.DataType.Type.NCHAR: "CHAR", - exp.DataType.Type.NVARCHAR: "VARCHAR", - exp.DataType.Type.MEDIUMTEXT: "TEXT", - exp.DataType.Type.LONGTEXT: "TEXT", - exp.DataType.Type.TINYTEXT: "TEXT", - exp.DataType.Type.BLOB: "VARBINARY", - exp.DataType.Type.MEDIUMBLOB: "BLOB", - exp.DataType.Type.LONGBLOB: "BLOB", - exp.DataType.Type.TINYBLOB: "BLOB", - exp.DataType.Type.INET: "INET", - exp.DataType.Type.ROWVERSION: "VARBINARY", - exp.DataType.Type.SMALLDATETIME: "TIMESTAMP", - } - - TIME_PART_SINGULARS = { - "MICROSECONDS": "MICROSECOND", - "SECONDS": "SECOND", - "MINUTES": "MINUTE", - "HOURS": "HOUR", - "DAYS": "DAY", - "WEEKS": "WEEK", - "MONTHS": "MONTH", - "QUARTERS": "QUARTER", - "YEARS": "YEAR", - } - - AFTER_HAVING_MODIFIER_TRANSFORMS = { - "cluster": lambda self, e: self.sql(e, "cluster"), - "distribute": lambda self, e: self.sql(e, "distribute"), - "sort": lambda self, e: self.sql(e, "sort"), - "windows": lambda self, e: ( - self.seg("WINDOW ") + self.expressions(e, key="windows", flat=True) - if e.args.get("windows") - else "" - ), - "qualify": lambda self, e: self.sql(e, "qualify"), - } - - TOKEN_MAPPING: t.Dict[TokenType, str] = {} 
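The settings documented in the class docstring above (pretty, identify, normalize, and so on) are normally passed through sqlglot's top-level API rather than by constructing a Generator directly. A minimal sketch, assuming sqlglot is installed; the query text and dialect names are placeholders:

import sqlglot

# pretty/identify are forwarded to the target dialect's Generator
sql = sqlglot.transpile(
    "SELECT a, b FROM x WHERE a > 1",
    read="mysql",
    write="postgres",
    pretty=True,
    identify=True,
)[0]
print(sql)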
- - STRUCT_DELIMITER = ("<", ">") - - PARAMETER_TOKEN = "@" - NAMED_PLACEHOLDER_TOKEN = ":" - - EXPRESSION_PRECEDES_PROPERTIES_CREATABLES: t.Set[str] = set() - - PROPERTIES_LOCATION = { - exp.AllowedValuesProperty: exp.Properties.Location.POST_SCHEMA, - exp.AlgorithmProperty: exp.Properties.Location.POST_CREATE, - exp.AutoIncrementProperty: exp.Properties.Location.POST_SCHEMA, - exp.AutoRefreshProperty: exp.Properties.Location.POST_SCHEMA, - exp.BackupProperty: exp.Properties.Location.POST_SCHEMA, - exp.BlockCompressionProperty: exp.Properties.Location.POST_NAME, - exp.CharacterSetProperty: exp.Properties.Location.POST_SCHEMA, - exp.ChecksumProperty: exp.Properties.Location.POST_NAME, - exp.CollateProperty: exp.Properties.Location.POST_SCHEMA, - exp.CopyGrantsProperty: exp.Properties.Location.POST_SCHEMA, - exp.Cluster: exp.Properties.Location.POST_SCHEMA, - exp.ClusteredByProperty: exp.Properties.Location.POST_SCHEMA, - exp.DistributedByProperty: exp.Properties.Location.POST_SCHEMA, - exp.DuplicateKeyProperty: exp.Properties.Location.POST_SCHEMA, - exp.DataBlocksizeProperty: exp.Properties.Location.POST_NAME, - exp.DataDeletionProperty: exp.Properties.Location.POST_SCHEMA, - exp.DefinerProperty: exp.Properties.Location.POST_CREATE, - exp.DictRange: exp.Properties.Location.POST_SCHEMA, - exp.DictProperty: exp.Properties.Location.POST_SCHEMA, - exp.DynamicProperty: exp.Properties.Location.POST_CREATE, - exp.DistKeyProperty: exp.Properties.Location.POST_SCHEMA, - exp.DistStyleProperty: exp.Properties.Location.POST_SCHEMA, - exp.EmptyProperty: exp.Properties.Location.POST_SCHEMA, - exp.EncodeProperty: exp.Properties.Location.POST_EXPRESSION, - exp.EngineProperty: exp.Properties.Location.POST_SCHEMA, - exp.EnviromentProperty: exp.Properties.Location.POST_SCHEMA, - exp.ExecuteAsProperty: exp.Properties.Location.POST_SCHEMA, - exp.ExternalProperty: exp.Properties.Location.POST_CREATE, - exp.FallbackProperty: exp.Properties.Location.POST_NAME, - exp.FileFormatProperty: exp.Properties.Location.POST_WITH, - exp.FreespaceProperty: exp.Properties.Location.POST_NAME, - exp.GlobalProperty: exp.Properties.Location.POST_CREATE, - exp.HeapProperty: exp.Properties.Location.POST_WITH, - exp.InheritsProperty: exp.Properties.Location.POST_SCHEMA, - exp.IcebergProperty: exp.Properties.Location.POST_CREATE, - exp.IncludeProperty: exp.Properties.Location.POST_SCHEMA, - exp.InputModelProperty: exp.Properties.Location.POST_SCHEMA, - exp.IsolatedLoadingProperty: exp.Properties.Location.POST_NAME, - exp.JournalProperty: exp.Properties.Location.POST_NAME, - exp.LanguageProperty: exp.Properties.Location.POST_SCHEMA, - exp.LikeProperty: exp.Properties.Location.POST_SCHEMA, - exp.LocationProperty: exp.Properties.Location.POST_SCHEMA, - exp.LockProperty: exp.Properties.Location.POST_SCHEMA, - exp.LockingProperty: exp.Properties.Location.POST_ALIAS, - exp.LogProperty: exp.Properties.Location.POST_NAME, - exp.MaterializedProperty: exp.Properties.Location.POST_CREATE, - exp.MergeBlockRatioProperty: exp.Properties.Location.POST_NAME, - exp.NoPrimaryIndexProperty: exp.Properties.Location.POST_EXPRESSION, - exp.OnProperty: exp.Properties.Location.POST_SCHEMA, - exp.OnCommitProperty: exp.Properties.Location.POST_EXPRESSION, - exp.Order: exp.Properties.Location.POST_SCHEMA, - exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA, - exp.PartitionedByProperty: exp.Properties.Location.POST_WITH, - exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA, - exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA, - 
exp.Property: exp.Properties.Location.POST_WITH, - exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA, - exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA, - exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA, - exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA, - exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA, - exp.SampleProperty: exp.Properties.Location.POST_SCHEMA, - exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA, - exp.SecureProperty: exp.Properties.Location.POST_CREATE, - exp.SecurityProperty: exp.Properties.Location.POST_SCHEMA, - exp.SerdeProperties: exp.Properties.Location.POST_SCHEMA, - exp.Set: exp.Properties.Location.POST_SCHEMA, - exp.SettingsProperty: exp.Properties.Location.POST_SCHEMA, - exp.SetProperty: exp.Properties.Location.POST_CREATE, - exp.SetConfigProperty: exp.Properties.Location.POST_SCHEMA, - exp.SharingProperty: exp.Properties.Location.POST_EXPRESSION, - exp.SequenceProperties: exp.Properties.Location.POST_EXPRESSION, - exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA, - exp.SqlReadWriteProperty: exp.Properties.Location.POST_SCHEMA, - exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE, - exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA, - exp.StorageHandlerProperty: exp.Properties.Location.POST_SCHEMA, - exp.StreamingTableProperty: exp.Properties.Location.POST_CREATE, - exp.StrictProperty: exp.Properties.Location.POST_SCHEMA, - exp.Tags: exp.Properties.Location.POST_WITH, - exp.TemporaryProperty: exp.Properties.Location.POST_CREATE, - exp.ToTableProperty: exp.Properties.Location.POST_SCHEMA, - exp.TransientProperty: exp.Properties.Location.POST_CREATE, - exp.TransformModelProperty: exp.Properties.Location.POST_SCHEMA, - exp.MergeTreeTTL: exp.Properties.Location.POST_SCHEMA, - exp.UnloggedProperty: exp.Properties.Location.POST_CREATE, - exp.UsingTemplateProperty: exp.Properties.Location.POST_SCHEMA, - exp.ViewAttributeProperty: exp.Properties.Location.POST_SCHEMA, - exp.VolatileProperty: exp.Properties.Location.POST_CREATE, - exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION, - exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME, - exp.WithProcedureOptions: exp.Properties.Location.POST_SCHEMA, - exp.WithSchemaBindingProperty: exp.Properties.Location.POST_SCHEMA, - exp.WithSystemVersioningProperty: exp.Properties.Location.POST_SCHEMA, - exp.ForceProperty: exp.Properties.Location.POST_CREATE, - } - - # Keywords that can't be used as unquoted identifier names - RESERVED_KEYWORDS: t.Set[str] = set() - - # Expressions whose comments are separated from them for better formatting - WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( - exp.Command, - exp.Create, - exp.Describe, - exp.Delete, - exp.Drop, - exp.From, - exp.Insert, - exp.Join, - exp.MultitableInserts, - exp.Select, - exp.SetOperation, - exp.Update, - exp.Where, - exp.With, - ) - - # Expressions that should not have their comments generated in maybe_comment - EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( - exp.Binary, - exp.SetOperation, - ) - - # Expressions that can remain unwrapped when appearing in the context of an INTERVAL - UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] 
= ( - exp.Column, - exp.Literal, - exp.Neg, - exp.Paren, - ) - - PARAMETERIZABLE_TEXT_TYPES = { - exp.DataType.Type.NVARCHAR, - exp.DataType.Type.VARCHAR, - exp.DataType.Type.CHAR, - exp.DataType.Type.NCHAR, - } - - # Expressions that need to have all CTEs under them bubbled up to them - EXPRESSIONS_WITHOUT_NESTED_CTES: t.Set[t.Type[exp.Expression]] = set() - - RESPECT_IGNORE_NULLS_UNSUPPORTED_EXPRESSIONS: t.Tuple[t.Type[exp.Expression], ...] = () - - SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" - - __slots__ = ( - "pretty", - "identify", - "normalize", - "pad", - "_indent", - "normalize_functions", - "unsupported_level", - "max_unsupported", - "leading_comma", - "max_text_width", - "comments", - "dialect", - "unsupported_messages", - "_escaped_quote_end", - "_escaped_identifier_end", - "_next_name", - "_identifier_start", - "_identifier_end", - "_quote_json_path_key_using_brackets", - ) - - def __init__( - self, - pretty: t.Optional[bool] = None, - identify: str | bool = False, - normalize: bool = False, - pad: int = 2, - indent: int = 2, - normalize_functions: t.Optional[str | bool] = None, - unsupported_level: ErrorLevel = ErrorLevel.WARN, - max_unsupported: int = 3, - leading_comma: bool = False, - max_text_width: int = 80, - comments: bool = True, - dialect: DialectType = None, - ): - import sqlglot - from sqlglot.dialects import Dialect - - self.pretty = pretty if pretty is not None else sqlglot.pretty - self.identify = identify - self.normalize = normalize - self.pad = pad - self._indent = indent - self.unsupported_level = unsupported_level - self.max_unsupported = max_unsupported - self.leading_comma = leading_comma - self.max_text_width = max_text_width - self.comments = comments - self.dialect = Dialect.get_or_raise(dialect) - - # This is both a Dialect property and a Generator argument, so we prioritize the latter - self.normalize_functions = ( - self.dialect.NORMALIZE_FUNCTIONS if normalize_functions is None else normalize_functions - ) - - self.unsupported_messages: t.List[str] = [] - self._escaped_quote_end: str = ( - self.dialect.tokenizer_class.STRING_ESCAPES[0] + self.dialect.QUOTE_END - ) - self._escaped_identifier_end = self.dialect.IDENTIFIER_END * 2 - - self._next_name = name_sequence("_t") - - self._identifier_start = self.dialect.IDENTIFIER_START - self._identifier_end = self.dialect.IDENTIFIER_END - - self._quote_json_path_key_using_brackets = True - - def generate(self, expression: exp.Expression, copy: bool = True) -> str: - """ - Generates the SQL string corresponding to the given syntax tree. - - Args: - expression: The syntax tree. - copy: Whether to copy the expression. The generator performs mutations so - it is safer to copy. - - Returns: - The SQL string corresponding to `expression`. 
- """ - if copy: - expression = expression.copy() - - expression = self.preprocess(expression) - - self.unsupported_messages = [] - sql = self.sql(expression).strip() - - if self.pretty: - sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n") - - if self.unsupported_level == ErrorLevel.IGNORE: - return sql - - if self.unsupported_level == ErrorLevel.WARN: - for msg in self.unsupported_messages: - logger.warning(msg) - elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: - raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported)) - - return sql - - def preprocess(self, expression: exp.Expression) -> exp.Expression: - """Apply generic preprocessing transformations to a given expression.""" - expression = self._move_ctes_to_top_level(expression) - - if self.ENSURE_BOOLS: - from sqlglot.transforms import ensure_bools - - expression = ensure_bools(expression) - - return expression - - def _move_ctes_to_top_level(self, expression: E) -> E: - if ( - not expression.parent - and type(expression) in self.EXPRESSIONS_WITHOUT_NESTED_CTES - and any(node.parent is not expression for node in expression.find_all(exp.With)) - ): - from sqlglot.transforms import move_ctes_to_top_level - - expression = move_ctes_to_top_level(expression) - return expression - - def unsupported(self, message: str) -> None: - if self.unsupported_level == ErrorLevel.IMMEDIATE: - raise UnsupportedError(message) - self.unsupported_messages.append(message) - - def sep(self, sep: str = " ") -> str: - return f"{sep.strip()}\n" if self.pretty else sep - - def seg(self, sql: str, sep: str = " ") -> str: - return f"{self.sep(sep)}{sql}" - - def sanitize_comment(self, comment: str) -> str: - comment = " " + comment if comment[0].strip() else comment - comment = comment + " " if comment[-1].strip() else comment - - if not self.dialect.tokenizer_class.NESTED_COMMENTS: - # Necessary workaround to avoid syntax errors due to nesting: /* ... */ ... 
*/ - comment = comment.replace("*/", "* /") - - return comment - - def maybe_comment( - self, - sql: str, - expression: t.Optional[exp.Expression] = None, - comments: t.Optional[t.List[str]] = None, - separated: bool = False, - ) -> str: - comments = ( - ((expression and expression.comments) if comments is None else comments) # type: ignore - if self.comments - else None - ) - - if not comments or isinstance(expression, self.EXCLUDE_COMMENTS): - return sql - - comments_sql = " ".join( - f"/*{self.sanitize_comment(comment)}*/" for comment in comments if comment - ) - - if not comments_sql: - return sql - - comments_sql = self._replace_line_breaks(comments_sql) - - if separated or isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return ( - f"{self.sep()}{comments_sql}{sql}" - if not sql or sql[0].isspace() - else f"{comments_sql}{self.sep()}{sql}" - ) - - return f"{sql} {comments_sql}" - - def wrap(self, expression: exp.Expression | str) -> str: - this_sql = ( - self.sql(expression) - if isinstance(expression, exp.UNWRAPPED_QUERIES) - else self.sql(expression, "this") - ) - if not this_sql: - return "()" - - this_sql = self.indent(this_sql, level=1, pad=0) - return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" - - def no_identify(self, func: t.Callable[..., str], *args, **kwargs) -> str: - original = self.identify - self.identify = False - result = func(*args, **kwargs) - self.identify = original - return result - - def normalize_func(self, name: str) -> str: - if self.normalize_functions == "upper" or self.normalize_functions is True: - return name.upper() - if self.normalize_functions == "lower": - return name.lower() - return name - - def indent( - self, - sql: str, - level: int = 0, - pad: t.Optional[int] = None, - skip_first: bool = False, - skip_last: bool = False, - ) -> str: - if not self.pretty or not sql: - return sql - - pad = self.pad if pad is None else pad - lines = sql.split("\n") - - return "\n".join( - ( - line - if (skip_first and i == 0) or (skip_last and i == len(lines) - 1) - else f"{' ' * (level * self._indent + pad)}{line}" - ) - for i, line in enumerate(lines) - ) - - def sql( - self, - expression: t.Optional[str | exp.Expression], - key: t.Optional[str] = None, - comment: bool = True, - ) -> str: - if not expression: - return "" - - if isinstance(expression, str): - return expression - - if key: - value = expression.args.get(key) - if value: - return self.sql(value) - return "" - - transform = self.TRANSFORMS.get(expression.__class__) - - if callable(transform): - sql = transform(self, expression) - elif isinstance(expression, exp.Expression): - exp_handler_name = f"{expression.key}_sql" - - if hasattr(self, exp_handler_name): - sql = getattr(self, exp_handler_name)(expression) - elif isinstance(expression, exp.Func): - sql = self.function_fallback_sql(expression) - elif isinstance(expression, exp.Property): - sql = self.property_sql(expression) - else: - raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") - else: - raise ValueError(f"Expected an Expression. 
Received {type(expression)}: {expression}") - - return self.maybe_comment(sql, expression) if self.comments and comment else sql - - def uncache_sql(self, expression: exp.Uncache) -> str: - table = self.sql(expression, "this") - exists_sql = " IF EXISTS" if expression.args.get("exists") else "" - return f"UNCACHE TABLE{exists_sql} {table}" - - def cache_sql(self, expression: exp.Cache) -> str: - lazy = " LAZY" if expression.args.get("lazy") else "" - table = self.sql(expression, "this") - options = expression.args.get("options") - options = f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" if options else "" - sql = self.sql(expression, "expression") - sql = f" AS{self.sep()}{sql}" if sql else "" - sql = f"CACHE{lazy} TABLE {table}{options}{sql}" - return self.prepend_ctes(expression, sql) - - def characterset_sql(self, expression: exp.CharacterSet) -> str: - if isinstance(expression.parent, exp.Cast): - return f"CHAR CHARACTER SET {self.sql(expression, 'this')}" - default = "DEFAULT " if expression.args.get("default") else "" - return f"{default}CHARACTER SET={self.sql(expression, 'this')}" - - def column_parts(self, expression: exp.Column) -> str: - return ".".join( - self.sql(part) - for part in ( - expression.args.get("catalog"), - expression.args.get("db"), - expression.args.get("table"), - expression.args.get("this"), - ) - if part - ) - - def column_sql(self, expression: exp.Column) -> str: - join_mark = " (+)" if expression.args.get("join_mark") else "" - - if join_mark and not self.dialect.SUPPORTS_COLUMN_JOIN_MARKS: - join_mark = "" - self.unsupported("Outer join syntax using the (+) operator is not supported.") - - return f"{self.column_parts(expression)}{join_mark}" - - def columnposition_sql(self, expression: exp.ColumnPosition) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - position = self.sql(expression, "position") - return f"{position}{this}" - - def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: - column = self.sql(expression, "this") - kind = self.sql(expression, "kind") - constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) - exists = "IF NOT EXISTS " if expression.args.get("exists") else "" - kind = f"{sep}{kind}" if kind else "" - constraints = f" {constraints}" if constraints else "" - position = self.sql(expression, "position") - position = f" {position}" if position else "" - - if expression.find(exp.ComputedColumnConstraint) and not self.COMPUTED_COLUMN_WITH_TYPE: - kind = "" - - return f"{exists}{column}{kind}{constraints}{position}" - - def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: - this = self.sql(expression, "this") - kind_sql = self.sql(expression, "kind").strip() - return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql - - def computedcolumnconstraint_sql(self, expression: exp.ComputedColumnConstraint) -> str: - this = self.sql(expression, "this") - if expression.args.get("not_null"): - persisted = " PERSISTED NOT NULL" - elif expression.args.get("persisted"): - persisted = " PERSISTED" - else: - persisted = "" - - return f"AS {this}{persisted}" - - def autoincrementcolumnconstraint_sql(self, _) -> str: - return self.token_sql(TokenType.AUTO_INCREMENT) - - def compresscolumnconstraint_sql(self, expression: exp.CompressColumnConstraint) -> str: - if isinstance(expression.this, list): - this = self.wrap(self.expressions(expression, key="this", flat=True)) - else: - this = self.sql(expression, "this") - - return f"COMPRESS 
{this}" - - def generatedasidentitycolumnconstraint_sql( - self, expression: exp.GeneratedAsIdentityColumnConstraint - ) -> str: - this = "" - if expression.this is not None: - on_null = " ON NULL" if expression.args.get("on_null") else "" - this = " ALWAYS" if expression.this else f" BY DEFAULT{on_null}" - - start = expression.args.get("start") - start = f"START WITH {start}" if start else "" - increment = expression.args.get("increment") - increment = f" INCREMENT BY {increment}" if increment else "" - minvalue = expression.args.get("minvalue") - minvalue = f" MINVALUE {minvalue}" if minvalue else "" - maxvalue = expression.args.get("maxvalue") - maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" - cycle = expression.args.get("cycle") - cycle_sql = "" - - if cycle is not None: - cycle_sql = f"{' NO' if not cycle else ''} CYCLE" - cycle_sql = cycle_sql.strip() if not start and not increment else cycle_sql - - sequence_opts = "" - if start or increment or cycle_sql: - sequence_opts = f"{start}{increment}{minvalue}{maxvalue}{cycle_sql}" - sequence_opts = f" ({sequence_opts.strip()})" - - expr = self.sql(expression, "expression") - expr = f"({expr})" if expr else "IDENTITY" - - return f"GENERATED{this} AS {expr}{sequence_opts}" - - def generatedasrowcolumnconstraint_sql( - self, expression: exp.GeneratedAsRowColumnConstraint - ) -> str: - start = "START" if expression.args.get("start") else "END" - hidden = " HIDDEN" if expression.args.get("hidden") else "" - return f"GENERATED ALWAYS AS ROW {start}{hidden}" - - def periodforsystemtimeconstraint_sql( - self, expression: exp.PeriodForSystemTimeConstraint - ) -> str: - return f"PERIOD FOR SYSTEM_TIME ({self.sql(expression, 'this')}, {self.sql(expression, 'expression')})" - - def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: - return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" - - def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str: - desc = expression.args.get("desc") - if desc is not None: - return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" - options = self.expressions(expression, key="options", flat=True, sep=" ") - options = f" {options}" if options else "" - return f"PRIMARY KEY{options}" - - def uniquecolumnconstraint_sql(self, expression: exp.UniqueColumnConstraint) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - index_type = expression.args.get("index_type") - index_type = f" USING {index_type}" if index_type else "" - on_conflict = self.sql(expression, "on_conflict") - on_conflict = f" {on_conflict}" if on_conflict else "" - nulls_sql = " NULLS NOT DISTINCT" if expression.args.get("nulls") else "" - options = self.expressions(expression, key="options", flat=True, sep=" ") - options = f" {options}" if options else "" - return f"UNIQUE{nulls_sql}{this}{index_type}{on_conflict}{options}" - - def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: - return self.sql(expression, "this") - - def create_sql(self, expression: exp.Create) -> str: - kind = self.sql(expression, "kind") - kind = self.dialect.INVERSE_CREATABLE_KIND_MAPPING.get(kind) or kind - properties = expression.args.get("properties") - properties_locs = self.locate_properties(properties) if properties else defaultdict() - - this = self.createable_sql(expression, properties_locs) - - properties_sql = "" - if properties_locs.get(exp.Properties.Location.POST_SCHEMA) or properties_locs.get( - 
exp.Properties.Location.POST_WITH - ): - properties_sql = self.sql( - exp.Properties( - expressions=[ - *properties_locs[exp.Properties.Location.POST_SCHEMA], - *properties_locs[exp.Properties.Location.POST_WITH], - ] - ) - ) - - if properties_locs.get(exp.Properties.Location.POST_SCHEMA): - properties_sql = self.sep() + properties_sql - elif not self.pretty: - # Standalone POST_WITH properties need a leading whitespace in non-pretty mode - properties_sql = f" {properties_sql}" - - begin = " BEGIN" if expression.args.get("begin") else "" - end = " END" if expression.args.get("end") else "" - - expression_sql = self.sql(expression, "expression") - if expression_sql: - expression_sql = f"{begin}{self.sep()}{expression_sql}{end}" - - if self.CREATE_FUNCTION_RETURN_AS or not isinstance(expression.expression, exp.Return): - postalias_props_sql = "" - if properties_locs.get(exp.Properties.Location.POST_ALIAS): - postalias_props_sql = self.properties( - exp.Properties( - expressions=properties_locs[exp.Properties.Location.POST_ALIAS] - ), - wrapped=False, - ) - postalias_props_sql = f" {postalias_props_sql}" if postalias_props_sql else "" - expression_sql = f" AS{postalias_props_sql}{expression_sql}" - - postindex_props_sql = "" - if properties_locs.get(exp.Properties.Location.POST_INDEX): - postindex_props_sql = self.properties( - exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_INDEX]), - wrapped=False, - prefix=" ", - ) - - indexes = self.expressions(expression, key="indexes", indent=False, sep=" ") - indexes = f" {indexes}" if indexes else "" - index_sql = indexes + postindex_props_sql - - replace = " OR REPLACE" if expression.args.get("replace") else "" - refresh = " OR REFRESH" if expression.args.get("refresh") else "" - unique = " UNIQUE" if expression.args.get("unique") else "" - - clustered = expression.args.get("clustered") - if clustered is None: - clustered_sql = "" - elif clustered: - clustered_sql = " CLUSTERED COLUMNSTORE" - else: - clustered_sql = " NONCLUSTERED COLUMNSTORE" - - postcreate_props_sql = "" - if properties_locs.get(exp.Properties.Location.POST_CREATE): - postcreate_props_sql = self.properties( - exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_CREATE]), - sep=" ", - prefix=" ", - wrapped=False, - ) - - modifiers = "".join((clustered_sql, replace, refresh, unique, postcreate_props_sql)) - - postexpression_props_sql = "" - if properties_locs.get(exp.Properties.Location.POST_EXPRESSION): - postexpression_props_sql = self.properties( - exp.Properties( - expressions=properties_locs[exp.Properties.Location.POST_EXPRESSION] - ), - sep=" ", - prefix=" ", - wrapped=False, - ) - - concurrently = " CONCURRENTLY" if expression.args.get("concurrently") else "" - exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" - no_schema_binding = ( - " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else "" - ) - - clone = self.sql(expression, "clone") - clone = f" {clone}" if clone else "" - - if kind in self.EXPRESSION_PRECEDES_PROPERTIES_CREATABLES: - properties_expression = f"{expression_sql}{properties_sql}" - else: - properties_expression = f"{properties_sql}{expression_sql}" - - expression_sql = f"CREATE{modifiers} {kind}{concurrently}{exists_sql} {this}{properties_expression}{postexpression_props_sql}{index_sql}{no_schema_binding}{clone}" - return self.prepend_ctes(expression, expression_sql) - - def sequenceproperties_sql(self, expression: exp.SequenceProperties) -> str: - start = 
self.sql(expression, "start") - start = f"START WITH {start}" if start else "" - increment = self.sql(expression, "increment") - increment = f" INCREMENT BY {increment}" if increment else "" - minvalue = self.sql(expression, "minvalue") - minvalue = f" MINVALUE {minvalue}" if minvalue else "" - maxvalue = self.sql(expression, "maxvalue") - maxvalue = f" MAXVALUE {maxvalue}" if maxvalue else "" - owned = self.sql(expression, "owned") - owned = f" OWNED BY {owned}" if owned else "" - - cache = expression.args.get("cache") - if cache is None: - cache_str = "" - elif cache is True: - cache_str = " CACHE" - else: - cache_str = f" CACHE {cache}" - - options = self.expressions(expression, key="options", flat=True, sep=" ") - options = f" {options}" if options else "" - - return f"{start}{increment}{minvalue}{maxvalue}{cache_str}{options}{owned}".lstrip() - - def clone_sql(self, expression: exp.Clone) -> str: - this = self.sql(expression, "this") - shallow = "SHALLOW " if expression.args.get("shallow") else "" - keyword = "COPY" if expression.args.get("copy") and self.SUPPORTS_TABLE_COPY else "CLONE" - return f"{shallow}{keyword} {this}" - - def describe_sql(self, expression: exp.Describe) -> str: - style = expression.args.get("style") - style = f" {style}" if style else "" - partition = self.sql(expression, "partition") - partition = f" {partition}" if partition else "" - format = self.sql(expression, "format") - format = f" {format}" if format else "" - - return f"DESCRIBE{style}{format} {self.sql(expression, 'this')}{partition}" - - def heredoc_sql(self, expression: exp.Heredoc) -> str: - tag = self.sql(expression, "tag") - return f"${tag}${self.sql(expression, 'this')}${tag}$" - - def prepend_ctes(self, expression: exp.Expression, sql: str) -> str: - with_ = self.sql(expression, "with") - if with_: - sql = f"{with_}{self.sep()}{sql}" - return sql - - def with_sql(self, expression: exp.With) -> str: - sql = self.expressions(expression, flat=True) - recursive = ( - "RECURSIVE " - if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive") - else "" - ) - search = self.sql(expression, "search") - search = f" {search}" if search else "" - - return f"WITH {recursive}{sql}{search}" - - def cte_sql(self, expression: exp.CTE) -> str: - alias = expression.args.get("alias") - if alias: - alias.add_comments(expression.pop_comments()) - - alias_sql = self.sql(expression, "alias") - - materialized = expression.args.get("materialized") - if materialized is False: - materialized = "NOT MATERIALIZED " - elif materialized: - materialized = "MATERIALIZED " - - return f"{alias_sql} AS {materialized or ''}{self.wrap(expression)}" - - def tablealias_sql(self, expression: exp.TableAlias) -> str: - alias = self.sql(expression, "this") - columns = self.expressions(expression, key="columns", flat=True) - columns = f"({columns})" if columns else "" - - if columns and not self.SUPPORTS_TABLE_ALIAS_COLUMNS: - columns = "" - self.unsupported("Named columns are not supported in table alias.") - - if not alias and not self.dialect.UNNEST_COLUMN_ONLY: - alias = self._next_name() - - return f"{alias}{columns}" - - def bitstring_sql(self, expression: exp.BitString) -> str: - this = self.sql(expression, "this") - if self.dialect.BIT_START: - return f"{self.dialect.BIT_START}{this}{self.dialect.BIT_END}" - return f"{int(this, 2)}" - - def hexstring_sql( - self, expression: exp.HexString, binary_function_repr: t.Optional[str] = None - ) -> str: - this = self.sql(expression, "this") - is_integer_type = 
expression.args.get("is_integer") - - if (is_integer_type and not self.dialect.HEX_STRING_IS_INTEGER_TYPE) or ( - not self.dialect.HEX_START and not binary_function_repr - ): - # Integer representation will be returned if: - # - The read dialect treats the hex value as integer literal but not the write - # - The transpilation is not supported (write dialect hasn't set HEX_START or the param flag) - return f"{int(this, 16)}" - - if not is_integer_type: - # Read dialect treats the hex value as BINARY/BLOB - if binary_function_repr: - # The write dialect supports the transpilation to its equivalent BINARY/BLOB - return self.func(binary_function_repr, exp.Literal.string(this)) - if self.dialect.HEX_STRING_IS_INTEGER_TYPE: - # The write dialect does not support the transpilation, it'll treat the hex value as INTEGER - self.unsupported("Unsupported transpilation from BINARY/BLOB hex string") - - return f"{self.dialect.HEX_START}{this}{self.dialect.HEX_END}" - - def bytestring_sql(self, expression: exp.ByteString) -> str: - this = self.sql(expression, "this") - if self.dialect.BYTE_START: - return f"{self.dialect.BYTE_START}{this}{self.dialect.BYTE_END}" - return this - - def unicodestring_sql(self, expression: exp.UnicodeString) -> str: - this = self.sql(expression, "this") - escape = expression.args.get("escape") - - if self.dialect.UNICODE_START: - escape_substitute = r"\\\1" - left_quote, right_quote = self.dialect.UNICODE_START, self.dialect.UNICODE_END - else: - escape_substitute = r"\\u\1" - left_quote, right_quote = self.dialect.QUOTE_START, self.dialect.QUOTE_END - - if escape: - escape_pattern = re.compile(rf"{escape.name}(\d+)") - escape_sql = f" UESCAPE {self.sql(escape)}" if self.SUPPORTS_UESCAPE else "" - else: - escape_pattern = ESCAPED_UNICODE_RE - escape_sql = "" - - if not self.dialect.UNICODE_START or (escape and not self.SUPPORTS_UESCAPE): - this = escape_pattern.sub(escape_substitute, this) - - return f"{left_quote}{this}{right_quote}{escape_sql}" - - def rawstring_sql(self, expression: exp.RawString) -> str: - string = expression.this - if "\\" in self.dialect.tokenizer_class.STRING_ESCAPES: - string = string.replace("\\", "\\\\") - - string = self.escape_str(string, escape_backslash=False) - return f"{self.dialect.QUOTE_START}{string}{self.dialect.QUOTE_END}" - - def datatypeparam_sql(self, expression: exp.DataTypeParam) -> str: - this = self.sql(expression, "this") - specifier = self.sql(expression, "expression") - specifier = f" {specifier}" if specifier and self.DATA_TYPE_SPECIFIERS_ALLOWED else "" - return f"{this}{specifier}" - - def datatype_sql(self, expression: exp.DataType) -> str: - nested = "" - values = "" - interior = self.expressions(expression, flat=True) - - type_value = expression.this - if type_value == exp.DataType.Type.USERDEFINED and expression.args.get("kind"): - type_sql = self.sql(expression, "kind") - else: - type_sql = ( - self.TYPE_MAPPING.get(type_value, type_value.value) - if isinstance(type_value, exp.DataType.Type) - else type_value - ) - - if interior: - if expression.args.get("nested"): - nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" - if expression.args.get("values") is not None: - delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")") - values = self.expressions(expression, key="values", flat=True) - values = f"{delimiters[0]}{values}{delimiters[1]}" - elif type_value == exp.DataType.Type.INTERVAL: - nested = f" {interior}" - else: - nested = f"({interior})" - - type_sql = 
f"{type_sql}{nested}{values}" - if self.TZ_TO_WITH_TIME_ZONE and type_value in ( - exp.DataType.Type.TIMETZ, - exp.DataType.Type.TIMESTAMPTZ, - ): - type_sql = f"{type_sql} WITH TIME ZONE" - - return type_sql - - def directory_sql(self, expression: exp.Directory) -> str: - local = "LOCAL " if expression.args.get("local") else "" - row_format = self.sql(expression, "row_format") - row_format = f" {row_format}" if row_format else "" - return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}" - - def delete_sql(self, expression: exp.Delete) -> str: - this = self.sql(expression, "this") - this = f" FROM {this}" if this else "" - using = self.sql(expression, "using") - using = f" USING {using}" if using else "" - cluster = self.sql(expression, "cluster") - cluster = f" {cluster}" if cluster else "" - where = self.sql(expression, "where") - returning = self.sql(expression, "returning") - limit = self.sql(expression, "limit") - tables = self.expressions(expression, key="tables") - tables = f" {tables}" if tables else "" - if self.RETURNING_END: - expression_sql = f"{this}{using}{cluster}{where}{returning}{limit}" - else: - expression_sql = f"{returning}{this}{using}{cluster}{where}{limit}" - return self.prepend_ctes(expression, f"DELETE{tables}{expression_sql}") - - def drop_sql(self, expression: exp.Drop) -> str: - this = self.sql(expression, "this") - expressions = self.expressions(expression, flat=True) - expressions = f" ({expressions})" if expressions else "" - kind = expression.args["kind"] - kind = self.dialect.INVERSE_CREATABLE_KIND_MAPPING.get(kind) or kind - exists_sql = " IF EXISTS " if expression.args.get("exists") else " " - concurrently_sql = " CONCURRENTLY" if expression.args.get("concurrently") else "" - on_cluster = self.sql(expression, "cluster") - on_cluster = f" {on_cluster}" if on_cluster else "" - temporary = " TEMPORARY" if expression.args.get("temporary") else "" - materialized = " MATERIALIZED" if expression.args.get("materialized") else "" - cascade = " CASCADE" if expression.args.get("cascade") else "" - constraints = " CONSTRAINTS" if expression.args.get("constraints") else "" - purge = " PURGE" if expression.args.get("purge") else "" - return f"DROP{temporary}{materialized} {kind}{concurrently_sql}{exists_sql}{this}{on_cluster}{expressions}{cascade}{constraints}{purge}" - - def set_operation(self, expression: exp.SetOperation) -> str: - op_type = type(expression) - op_name = op_type.key.upper() - - distinct = expression.args.get("distinct") - if ( - distinct is False - and op_type in (exp.Except, exp.Intersect) - and not self.EXCEPT_INTERSECT_SUPPORT_ALL_CLAUSE - ): - self.unsupported(f"{op_name} ALL is not supported") - - default_distinct = self.dialect.SET_OP_DISTINCT_BY_DEFAULT[op_type] - - if distinct is None: - distinct = default_distinct - if distinct is None: - self.unsupported(f"{op_name} requires DISTINCT or ALL to be specified") - - if distinct is default_distinct: - distinct_or_all = "" - else: - distinct_or_all = " DISTINCT" if distinct else " ALL" - - side_kind = " ".join(filter(None, [expression.side, expression.kind])) - side_kind = f"{side_kind} " if side_kind else "" - - by_name = " BY NAME" if expression.args.get("by_name") else "" - on = self.expressions(expression, key="on", flat=True) - on = f" ON ({on})" if on else "" - - return f"{side_kind}{op_name}{distinct_or_all}{by_name}{on}" - - def set_operations(self, expression: exp.SetOperation) -> str: - if not self.SET_OP_MODIFIERS: - limit = expression.args.get("limit") - order = 
expression.args.get("order") - - if limit or order: - select = self._move_ctes_to_top_level( - exp.subquery(expression, "_l_0", copy=False).select("*", copy=False) - ) - - if limit: - select = select.limit(limit.pop(), copy=False) - if order: - select = select.order_by(order.pop(), copy=False) - return self.sql(select) - - sqls: t.List[str] = [] - stack: t.List[t.Union[str, exp.Expression]] = [expression] - - while stack: - node = stack.pop() - - if isinstance(node, exp.SetOperation): - stack.append(node.expression) - stack.append( - self.maybe_comment( - self.set_operation(node), comments=node.comments, separated=True - ) - ) - stack.append(node.this) - else: - sqls.append(self.sql(node)) - - this = self.sep().join(sqls) - this = self.query_modifiers(expression, this) - return self.prepend_ctes(expression, this) - - def fetch_sql(self, expression: exp.Fetch) -> str: - direction = expression.args.get("direction") - direction = f" {direction}" if direction else "" - count = self.sql(expression, "count") - count = f" {count}" if count else "" - limit_options = self.sql(expression, "limit_options") - limit_options = f"{limit_options}" if limit_options else " ROWS ONLY" - return f"{self.seg('FETCH')}{direction}{count}{limit_options}" - - def limitoptions_sql(self, expression: exp.LimitOptions) -> str: - percent = " PERCENT" if expression.args.get("percent") else "" - rows = " ROWS" if expression.args.get("rows") else "" - with_ties = " WITH TIES" if expression.args.get("with_ties") else "" - if not with_ties and rows: - with_ties = " ONLY" - return f"{percent}{rows}{with_ties}" - - def filter_sql(self, expression: exp.Filter) -> str: - if self.AGGREGATE_FILTER_SUPPORTED: - this = self.sql(expression, "this") - where = self.sql(expression, "expression").strip() - return f"{this} FILTER({where})" - - agg = expression.this - agg_arg = agg.this - cond = expression.expression.this - agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy())) - return self.sql(agg) - - def hint_sql(self, expression: exp.Hint) -> str: - if not self.QUERY_HINTS: - self.unsupported("Hints are not supported") - return "" - - return f" /*+ {self.expressions(expression, sep=self.QUERY_HINT_SEP).strip()} */" - - def indexparameters_sql(self, expression: exp.IndexParameters) -> str: - using = self.sql(expression, "using") - using = f" USING {using}" if using else "" - columns = self.expressions(expression, key="columns", flat=True) - columns = f"({columns})" if columns else "" - partition_by = self.expressions(expression, key="partition_by", flat=True) - partition_by = f" PARTITION BY {partition_by}" if partition_by else "" - where = self.sql(expression, "where") - include = self.expressions(expression, key="include", flat=True) - if include: - include = f" INCLUDE ({include})" - with_storage = self.expressions(expression, key="with_storage", flat=True) - with_storage = f" WITH ({with_storage})" if with_storage else "" - tablespace = self.sql(expression, "tablespace") - tablespace = f" USING INDEX TABLESPACE {tablespace}" if tablespace else "" - on = self.sql(expression, "on") - on = f" ON {on}" if on else "" - - return f"{using}{columns}{include}{with_storage}{tablespace}{partition_by}{where}{on}" - - def index_sql(self, expression: exp.Index) -> str: - unique = "UNIQUE " if expression.args.get("unique") else "" - primary = "PRIMARY " if expression.args.get("primary") else "" - amp = "AMP " if expression.args.get("amp") else "" - name = self.sql(expression, "this") - name = f"{name} " if name else "" - table = 
self.sql(expression, "table") - table = f"{self.INDEX_ON} {table}" if table else "" - - index = "INDEX " if not table else "" - - params = self.sql(expression, "params") - return f"{unique}{primary}{amp}{index}{name}{table}{params}" - - def identifier_sql(self, expression: exp.Identifier) -> str: - text = expression.name - lower = text.lower() - text = lower if self.normalize and not expression.quoted else text - text = text.replace(self._identifier_end, self._escaped_identifier_end) - if ( - expression.quoted - or self.dialect.can_identify(text, self.identify) - or lower in self.RESERVED_KEYWORDS - or (not self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT and text[:1].isdigit()) - ): - text = f"{self._identifier_start}{text}{self._identifier_end}" - return text - - def hex_sql(self, expression: exp.Hex) -> str: - text = self.func(self.HEX_FUNC, self.sql(expression, "this")) - if self.dialect.HEX_LOWERCASE: - text = self.func("LOWER", text) - - return text - - def lowerhex_sql(self, expression: exp.LowerHex) -> str: - text = self.func(self.HEX_FUNC, self.sql(expression, "this")) - if not self.dialect.HEX_LOWERCASE: - text = self.func("LOWER", text) - return text - - def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str: - input_format = self.sql(expression, "input_format") - input_format = f"INPUTFORMAT {input_format}" if input_format else "" - output_format = self.sql(expression, "output_format") - output_format = f"OUTPUTFORMAT {output_format}" if output_format else "" - return self.sep().join((input_format, output_format)) - - def national_sql(self, expression: exp.National, prefix: str = "N") -> str: - string = self.sql(exp.Literal.string(expression.name)) - return f"{prefix}{string}" - - def partition_sql(self, expression: exp.Partition) -> str: - partition_keyword = "SUBPARTITION" if expression.args.get("subpartition") else "PARTITION" - return f"{partition_keyword}({self.expressions(expression, flat=True)})" - - def properties_sql(self, expression: exp.Properties) -> str: - root_properties = [] - with_properties = [] - - for p in expression.expressions: - p_loc = self.PROPERTIES_LOCATION[p.__class__] - if p_loc == exp.Properties.Location.POST_WITH: - with_properties.append(p) - elif p_loc == exp.Properties.Location.POST_SCHEMA: - root_properties.append(p) - - root_props = self.root_properties(exp.Properties(expressions=root_properties)) - with_props = self.with_properties(exp.Properties(expressions=with_properties)) - - if root_props and with_props and not self.pretty: - with_props = " " + with_props - - return root_props + with_props - - def root_properties(self, properties: exp.Properties) -> str: - if properties.expressions: - return self.expressions(properties, indent=False, sep=" ") - return "" - - def properties( - self, - properties: exp.Properties, - prefix: str = "", - sep: str = ", ", - suffix: str = "", - wrapped: bool = True, - ) -> str: - if properties.expressions: - expressions = self.expressions(properties, sep=sep, indent=False) - if expressions: - expressions = self.wrap(expressions) if wrapped else expressions - return f"{prefix}{' ' if prefix.strip() else ''}{expressions}{suffix}" - return "" - - def with_properties(self, properties: exp.Properties) -> str: - return self.properties(properties, prefix=self.seg(self.WITH_PROPERTIES_PREFIX, sep="")) - - def locate_properties(self, properties: exp.Properties) -> t.DefaultDict: - properties_locs = defaultdict(list) - for p in properties.expressions: - p_loc = self.PROPERTIES_LOCATION[p.__class__] - if 
p_loc != exp.Properties.Location.UNSUPPORTED: - properties_locs[p_loc].append(p) - else: - self.unsupported(f"Unsupported property {p.key}") - - return properties_locs - - def property_name(self, expression: exp.Property, string_key: bool = False) -> str: - if isinstance(expression.this, exp.Dot): - return self.sql(expression, "this") - return f"'{expression.name}'" if string_key else expression.name - - def property_sql(self, expression: exp.Property) -> str: - property_cls = expression.__class__ - if property_cls == exp.Property: - return f"{self.property_name(expression)}={self.sql(expression, 'value')}" - - property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls) - if not property_name: - self.unsupported(f"Unsupported property {expression.key}") - - return f"{property_name}={self.sql(expression, 'this')}" - - def likeproperty_sql(self, expression: exp.LikeProperty) -> str: - if self.SUPPORTS_CREATE_TABLE_LIKE: - options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) - options = f" {options}" if options else "" - - like = f"LIKE {self.sql(expression, 'this')}{options}" - if self.LIKE_PROPERTY_INSIDE_SCHEMA and not isinstance(expression.parent, exp.Schema): - like = f"({like})" - - return like - - if expression.expressions: - self.unsupported("Transpilation of LIKE property options is unsupported") - - select = exp.select("*").from_(expression.this).limit(0) - return f"AS {self.sql(select)}" - - def fallbackproperty_sql(self, expression: exp.FallbackProperty) -> str: - no = "NO " if expression.args.get("no") else "" - protection = " PROTECTION" if expression.args.get("protection") else "" - return f"{no}FALLBACK{protection}" - - def journalproperty_sql(self, expression: exp.JournalProperty) -> str: - no = "NO " if expression.args.get("no") else "" - local = expression.args.get("local") - local = f"{local} " if local else "" - dual = "DUAL " if expression.args.get("dual") else "" - before = "BEFORE " if expression.args.get("before") else "" - after = "AFTER " if expression.args.get("after") else "" - return f"{no}{local}{dual}{before}{after}JOURNAL" - - def freespaceproperty_sql(self, expression: exp.FreespaceProperty) -> str: - freespace = self.sql(expression, "this") - percent = " PERCENT" if expression.args.get("percent") else "" - return f"FREESPACE={freespace}{percent}" - - def checksumproperty_sql(self, expression: exp.ChecksumProperty) -> str: - if expression.args.get("default"): - property = "DEFAULT" - elif expression.args.get("on"): - property = "ON" - else: - property = "OFF" - return f"CHECKSUM={property}" - - def mergeblockratioproperty_sql(self, expression: exp.MergeBlockRatioProperty) -> str: - if expression.args.get("no"): - return "NO MERGEBLOCKRATIO" - if expression.args.get("default"): - return "DEFAULT MERGEBLOCKRATIO" - - percent = " PERCENT" if expression.args.get("percent") else "" - return f"MERGEBLOCKRATIO={self.sql(expression, 'this')}{percent}" - - def datablocksizeproperty_sql(self, expression: exp.DataBlocksizeProperty) -> str: - default = expression.args.get("default") - minimum = expression.args.get("minimum") - maximum = expression.args.get("maximum") - if default or minimum or maximum: - if default: - prop = "DEFAULT" - elif minimum: - prop = "MINIMUM" - else: - prop = "MAXIMUM" - return f"{prop} DATABLOCKSIZE" - units = expression.args.get("units") - units = f" {units}" if units else "" - return f"DATABLOCKSIZE={self.sql(expression, 'size')}{units}" - - def blockcompressionproperty_sql(self, expression: 
exp.BlockCompressionProperty) -> str: - autotemp = expression.args.get("autotemp") - always = expression.args.get("always") - default = expression.args.get("default") - manual = expression.args.get("manual") - never = expression.args.get("never") - - if autotemp is not None: - prop = f"AUTOTEMP({self.expressions(autotemp)})" - elif always: - prop = "ALWAYS" - elif default: - prop = "DEFAULT" - elif manual: - prop = "MANUAL" - elif never: - prop = "NEVER" - return f"BLOCKCOMPRESSION={prop}" - - def isolatedloadingproperty_sql(self, expression: exp.IsolatedLoadingProperty) -> str: - no = expression.args.get("no") - no = " NO" if no else "" - concurrent = expression.args.get("concurrent") - concurrent = " CONCURRENT" if concurrent else "" - target = self.sql(expression, "target") - target = f" {target}" if target else "" - return f"WITH{no}{concurrent} ISOLATED LOADING{target}" - - def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str: - if isinstance(expression.this, list): - return f"IN ({self.expressions(expression, key='this', flat=True)})" - if expression.this: - modulus = self.sql(expression, "this") - remainder = self.sql(expression, "expression") - return f"WITH (MODULUS {modulus}, REMAINDER {remainder})" - - from_expressions = self.expressions(expression, key="from_expressions", flat=True) - to_expressions = self.expressions(expression, key="to_expressions", flat=True) - return f"FROM ({from_expressions}) TO ({to_expressions})" - - def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str: - this = self.sql(expression, "this") - - for_values_or_default = expression.expression - if isinstance(for_values_or_default, exp.PartitionBoundSpec): - for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}" - else: - for_values_or_default = " DEFAULT" - - return f"PARTITION OF {this}{for_values_or_default}" - - def lockingproperty_sql(self, expression: exp.LockingProperty) -> str: - kind = expression.args.get("kind") - this = f" {self.sql(expression, 'this')}" if expression.this else "" - for_or_in = expression.args.get("for_or_in") - for_or_in = f" {for_or_in}" if for_or_in else "" - lock_type = expression.args.get("lock_type") - override = " OVERRIDE" if expression.args.get("override") else "" - return f"LOCKING {kind}{this}{for_or_in} {lock_type}{override}" - - def withdataproperty_sql(self, expression: exp.WithDataProperty) -> str: - data_sql = f"WITH {'NO ' if expression.args.get('no') else ''}DATA" - statistics = expression.args.get("statistics") - statistics_sql = "" - if statistics is not None: - statistics_sql = f" AND {'NO ' if not statistics else ''}STATISTICS" - return f"{data_sql}{statistics_sql}" - - def withsystemversioningproperty_sql(self, expression: exp.WithSystemVersioningProperty) -> str: - this = self.sql(expression, "this") - this = f"HISTORY_TABLE={this}" if this else "" - data_consistency: t.Optional[str] = self.sql(expression, "data_consistency") - data_consistency = ( - f"DATA_CONSISTENCY_CHECK={data_consistency}" if data_consistency else None - ) - retention_period: t.Optional[str] = self.sql(expression, "retention_period") - retention_period = ( - f"HISTORY_RETENTION_PERIOD={retention_period}" if retention_period else None - ) - - if this: - on_sql = self.func("ON", this, data_consistency, retention_period) - else: - on_sql = "ON" if expression.args.get("on") else "OFF" - - sql = f"SYSTEM_VERSIONING={on_sql}" - - return f"WITH({sql})" if expression.args.get("with") else sql - - def insert_sql(self, 
expression: exp.Insert) -> str: - hint = self.sql(expression, "hint") - overwrite = expression.args.get("overwrite") - - if isinstance(expression.this, exp.Directory): - this = " OVERWRITE" if overwrite else " INTO" - else: - this = self.INSERT_OVERWRITE if overwrite else " INTO" - - stored = self.sql(expression, "stored") - stored = f" {stored}" if stored else "" - alternative = expression.args.get("alternative") - alternative = f" OR {alternative}" if alternative else "" - ignore = " IGNORE" if expression.args.get("ignore") else "" - is_function = expression.args.get("is_function") - if is_function: - this = f"{this} FUNCTION" - this = f"{this} {self.sql(expression, 'this')}" - - exists = " IF EXISTS" if expression.args.get("exists") else "" - where = self.sql(expression, "where") - where = f"{self.sep()}REPLACE WHERE {where}" if where else "" - expression_sql = f"{self.sep()}{self.sql(expression, 'expression')}" - on_conflict = self.sql(expression, "conflict") - on_conflict = f" {on_conflict}" if on_conflict else "" - by_name = " BY NAME" if expression.args.get("by_name") else "" - returning = self.sql(expression, "returning") - - if self.RETURNING_END: - expression_sql = f"{expression_sql}{on_conflict}{returning}" - else: - expression_sql = f"{returning}{expression_sql}{on_conflict}" - - partition_by = self.sql(expression, "partition") - partition_by = f" {partition_by}" if partition_by else "" - settings = self.sql(expression, "settings") - settings = f" {settings}" if settings else "" - - source = self.sql(expression, "source") - source = f"TABLE {source}" if source else "" - - sql = f"INSERT{hint}{alternative}{ignore}{this}{stored}{by_name}{exists}{partition_by}{settings}{where}{expression_sql}{source}" - return self.prepend_ctes(expression, sql) - - def introducer_sql(self, expression: exp.Introducer) -> str: - return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - - def kill_sql(self, expression: exp.Kill) -> str: - kind = self.sql(expression, "kind") - kind = f" {kind}" if kind else "" - this = self.sql(expression, "this") - this = f" {this}" if this else "" - return f"KILL{kind}{this}" - - def pseudotype_sql(self, expression: exp.PseudoType) -> str: - return expression.name - - def objectidentifier_sql(self, expression: exp.ObjectIdentifier) -> str: - return expression.name - - def onconflict_sql(self, expression: exp.OnConflict) -> str: - conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT" - - constraint = self.sql(expression, "constraint") - constraint = f" ON CONSTRAINT {constraint}" if constraint else "" - - conflict_keys = self.expressions(expression, key="conflict_keys", flat=True) - conflict_keys = f"({conflict_keys}) " if conflict_keys else " " - action = self.sql(expression, "action") - - expressions = self.expressions(expression, flat=True) - if expressions: - set_keyword = "SET " if self.DUPLICATE_KEY_UPDATE_WITH_SET else "" - expressions = f" {set_keyword}{expressions}" - - where = self.sql(expression, "where") - return f"{conflict}{constraint}{conflict_keys}{action}{expressions}{where}" - - def returning_sql(self, expression: exp.Returning) -> str: - return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}" - - def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: - fields = self.sql(expression, "fields") - fields = f" FIELDS TERMINATED BY {fields}" if fields else "" - escaped = self.sql(expression, "escaped") - escaped = f" ESCAPED BY {escaped}" if 
escaped else "" - items = self.sql(expression, "collection_items") - items = f" COLLECTION ITEMS TERMINATED BY {items}" if items else "" - keys = self.sql(expression, "map_keys") - keys = f" MAP KEYS TERMINATED BY {keys}" if keys else "" - lines = self.sql(expression, "lines") - lines = f" LINES TERMINATED BY {lines}" if lines else "" - null = self.sql(expression, "null") - null = f" NULL DEFINED AS {null}" if null else "" - return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}" - - def withtablehint_sql(self, expression: exp.WithTableHint) -> str: - return f"WITH ({self.expressions(expression, flat=True)})" - - def indextablehint_sql(self, expression: exp.IndexTableHint) -> str: - this = f"{self.sql(expression, 'this')} INDEX" - target = self.sql(expression, "target") - target = f" FOR {target}" if target else "" - return f"{this}{target} ({self.expressions(expression, flat=True)})" - - def historicaldata_sql(self, expression: exp.HistoricalData) -> str: - this = self.sql(expression, "this") - kind = self.sql(expression, "kind") - expr = self.sql(expression, "expression") - return f"{this} ({kind} => {expr})" - - def table_parts(self, expression: exp.Table) -> str: - return ".".join( - self.sql(part) - for part in ( - expression.args.get("catalog"), - expression.args.get("db"), - expression.args.get("this"), - ) - if part is not None - ) - - def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str: - table = self.table_parts(expression) - only = "ONLY " if expression.args.get("only") else "" - partition = self.sql(expression, "partition") - partition = f" {partition}" if partition else "" - version = self.sql(expression, "version") - version = f" {version}" if version else "" - alias = self.sql(expression, "alias") - alias = f"{sep}{alias}" if alias else "" - - sample = self.sql(expression, "sample") - if self.dialect.ALIAS_POST_TABLESAMPLE: - sample_pre_alias = sample - sample_post_alias = "" - else: - sample_pre_alias = "" - sample_post_alias = sample - - hints = self.expressions(expression, key="hints", sep=" ") - hints = f" {hints}" if hints and self.TABLE_HINTS else "" - pivots = self.expressions(expression, key="pivots", sep="", flat=True) - joins = self.indent( - self.expressions(expression, key="joins", sep="", flat=True), skip_first=True - ) - laterals = self.expressions(expression, key="laterals", sep="") - - file_format = self.sql(expression, "format") - if file_format: - pattern = self.sql(expression, "pattern") - pattern = f", PATTERN => {pattern}" if pattern else "" - file_format = f" (FILE_FORMAT => {file_format}{pattern})" - - ordinality = expression.args.get("ordinality") or "" - if ordinality: - ordinality = f" WITH ORDINALITY{alias}" - alias = "" - - when = self.sql(expression, "when") - if when: - table = f"{table} {when}" - - changes = self.sql(expression, "changes") - changes = f" {changes}" if changes else "" - - rows_from = self.expressions(expression, key="rows_from") - if rows_from: - table = f"ROWS FROM {self.wrap(rows_from)}" - - return f"{only}{table}{changes}{partition}{version}{file_format}{sample_pre_alias}{alias}{hints}{pivots}{sample_post_alias}{joins}{laterals}{ordinality}" - - def tablefromrows_sql(self, expression: exp.TableFromRows) -> str: - table = self.func("TABLE", expression.this) - alias = self.sql(expression, "alias") - alias = f" AS {alias}" if alias else "" - sample = self.sql(expression, "sample") - pivots = self.expressions(expression, key="pivots", sep="", flat=True) - joins = self.indent( - 
self.expressions(expression, key="joins", sep="", flat=True), skip_first=True - ) - return f"{table}{alias}{pivots}{sample}{joins}" - - def tablesample_sql( - self, - expression: exp.TableSample, - tablesample_keyword: t.Optional[str] = None, - ) -> str: - method = self.sql(expression, "method") - method = f"{method} " if method and self.TABLESAMPLE_WITH_METHOD else "" - numerator = self.sql(expression, "bucket_numerator") - denominator = self.sql(expression, "bucket_denominator") - field = self.sql(expression, "bucket_field") - field = f" ON {field}" if field else "" - bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else "" - seed = self.sql(expression, "seed") - seed = f" {self.TABLESAMPLE_SEED_KEYWORD} ({seed})" if seed else "" - - size = self.sql(expression, "size") - if size and self.TABLESAMPLE_SIZE_IS_ROWS: - size = f"{size} ROWS" - - percent = self.sql(expression, "percent") - if percent and not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT: - percent = f"{percent} PERCENT" - - expr = f"{bucket}{percent}{size}" - if self.TABLESAMPLE_REQUIRES_PARENS: - expr = f"({expr})" - - return f" {tablesample_keyword or self.TABLESAMPLE_KEYWORDS} {method}{expr}{seed}" - - def pivot_sql(self, expression: exp.Pivot) -> str: - expressions = self.expressions(expression, flat=True) - direction = "UNPIVOT" if expression.unpivot else "PIVOT" - - group = self.sql(expression, "group") - - if expression.this: - this = self.sql(expression, "this") - if not expressions: - return f"UNPIVOT {this}" - - on = f"{self.seg('ON')} {expressions}" - into = self.sql(expression, "into") - into = f"{self.seg('INTO')} {into}" if into else "" - using = self.expressions(expression, key="using", flat=True) - using = f"{self.seg('USING')} {using}" if using else "" - return f"{direction} {this}{on}{into}{using}{group}" - - alias = self.sql(expression, "alias") - alias = f" AS {alias}" if alias else "" - - fields = self.expressions( - expression, - "fields", - sep=" ", - dynamic=True, - new_line=True, - skip_first=True, - skip_last=True, - ) - - include_nulls = expression.args.get("include_nulls") - if include_nulls is not None: - nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS " - else: - nulls = "" - - default_on_null = self.sql(expression, "default_on_null") - default_on_null = f" DEFAULT ON NULL ({default_on_null})" if default_on_null else "" - return f"{self.seg(direction)}{nulls}({expressions} FOR {fields}{default_on_null}{group}){alias}" - - def version_sql(self, expression: exp.Version) -> str: - this = f"FOR {expression.name}" - kind = expression.text("kind") - expr = self.sql(expression, "expression") - return f"{this} {kind} {expr}" - - def tuple_sql(self, expression: exp.Tuple) -> str: - return f"({self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)})" - - def update_sql(self, expression: exp.Update) -> str: - this = self.sql(expression, "this") - set_sql = self.expressions(expression, flat=True) - from_sql = self.sql(expression, "from") - where_sql = self.sql(expression, "where") - returning = self.sql(expression, "returning") - order = self.sql(expression, "order") - limit = self.sql(expression, "limit") - if self.RETURNING_END: - expression_sql = f"{from_sql}{where_sql}{returning}" - else: - expression_sql = f"{returning}{from_sql}{where_sql}" - sql = f"UPDATE {this} SET {set_sql}{expression_sql}{order}{limit}" - return self.prepend_ctes(expression, sql) - - def values_sql(self, expression: exp.Values, values_as_table: bool = True) -> 
str: - values_as_table = values_as_table and self.VALUES_AS_TABLE - - # The VALUES clause is still valid in an `INSERT INTO ..` statement, for example - if values_as_table or not expression.find_ancestor(exp.From, exp.Join): - args = self.expressions(expression) - alias = self.sql(expression, "alias") - values = f"VALUES{self.seg('')}{args}" - values = ( - f"({values})" - if self.WRAP_DERIVED_VALUES - and (alias or isinstance(expression.parent, (exp.From, exp.Table))) - else values - ) - return f"{values} AS {alias}" if alias else values - - # Converts `VALUES...` expression into a series of select unions. - alias_node = expression.args.get("alias") - column_names = alias_node and alias_node.columns - - selects: t.List[exp.Query] = [] - - for i, tup in enumerate(expression.expressions): - row = tup.expressions - - if i == 0 and column_names: - row = [ - exp.alias_(value, column_name) for value, column_name in zip(row, column_names) - ] - - selects.append(exp.Select(expressions=row)) - - if self.pretty: - # This may result in poor performance for large-cardinality `VALUES` tables, due to - # the deep nesting of the resulting exp.Unions. If this is a problem, either increase - # `sys.setrecursionlimit` to avoid RecursionErrors, or don't set `pretty`. - query = reduce(lambda x, y: exp.union(x, y, distinct=False, copy=False), selects) - return self.subquery_sql(query.subquery(alias_node and alias_node.this, copy=False)) - - alias = f" AS {self.sql(alias_node, 'this')}" if alias_node else "" - unions = " UNION ALL ".join(self.sql(select) for select in selects) - return f"({unions}){alias}" - - def var_sql(self, expression: exp.Var) -> str: - return self.sql(expression, "this") - - @unsupported_args("expressions") - def into_sql(self, expression: exp.Into) -> str: - temporary = " TEMPORARY" if expression.args.get("temporary") else "" - unlogged = " UNLOGGED" if expression.args.get("unlogged") else "" - return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" - - def from_sql(self, expression: exp.From) -> str: - return f"{self.seg('FROM')} {self.sql(expression, 'this')}" - - def groupingsets_sql(self, expression: exp.GroupingSets) -> str: - grouping_sets = self.expressions(expression, indent=False) - return f"GROUPING SETS {self.wrap(grouping_sets)}" - - def rollup_sql(self, expression: exp.Rollup) -> str: - expressions = self.expressions(expression, indent=False) - return f"ROLLUP {self.wrap(expressions)}" if expressions else "WITH ROLLUP" - - def cube_sql(self, expression: exp.Cube) -> str: - expressions = self.expressions(expression, indent=False) - return f"CUBE {self.wrap(expressions)}" if expressions else "WITH CUBE" - - def group_sql(self, expression: exp.Group) -> str: - group_by_all = expression.args.get("all") - if group_by_all is True: - modifier = " ALL" - elif group_by_all is False: - modifier = " DISTINCT" - else: - modifier = "" - - group_by = self.op_expressions(f"GROUP BY{modifier}", expression) - - grouping_sets = self.expressions(expression, key="grouping_sets") - cube = self.expressions(expression, key="cube") - rollup = self.expressions(expression, key="rollup") - - groupings = csv( - self.seg(grouping_sets) if grouping_sets else "", - self.seg(cube) if cube else "", - self.seg(rollup) if rollup else "", - self.seg("WITH TOTALS") if expression.args.get("totals") else "", - sep=self.GROUPINGS_SEP, - ) - - if ( - expression.expressions - and groupings - and groupings.strip() not in ("WITH CUBE", "WITH ROLLUP") - ): - group_by = 
f"{group_by}{self.GROUPINGS_SEP}" - - return f"{group_by}{groupings}" - - def having_sql(self, expression: exp.Having) -> str: - this = self.indent(self.sql(expression, "this")) - return f"{self.seg('HAVING')}{self.sep()}{this}" - - def connect_sql(self, expression: exp.Connect) -> str: - start = self.sql(expression, "start") - start = self.seg(f"START WITH {start}") if start else "" - nocycle = " NOCYCLE" if expression.args.get("nocycle") else "" - connect = self.sql(expression, "connect") - connect = self.seg(f"CONNECT BY{nocycle} {connect}") - return start + connect - - def prior_sql(self, expression: exp.Prior) -> str: - return f"PRIOR {self.sql(expression, 'this')}" - - def join_sql(self, expression: exp.Join) -> str: - if not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in ("SEMI", "ANTI"): - side = None - else: - side = expression.side - - op_sql = " ".join( - op - for op in ( - expression.method, - "GLOBAL" if expression.args.get("global") else None, - side, - expression.kind, - expression.hint if self.JOIN_HINTS else None, - ) - if op - ) - match_cond = self.sql(expression, "match_condition") - match_cond = f" MATCH_CONDITION ({match_cond})" if match_cond else "" - on_sql = self.sql(expression, "on") - using = expression.args.get("using") - - if not on_sql and using: - on_sql = csv(*(self.sql(column) for column in using)) - - this = expression.this - this_sql = self.sql(this) - - exprs = self.expressions(expression) - if exprs: - this_sql = f"{this_sql},{self.seg(exprs)}" - - if on_sql: - on_sql = self.indent(on_sql, skip_first=True) - space = self.seg(" " * self.pad) if self.pretty else " " - if using: - on_sql = f"{space}USING ({on_sql})" - else: - on_sql = f"{space}ON {on_sql}" - elif not op_sql: - if isinstance(this, exp.Lateral) and this.args.get("cross_apply") is not None: - return f" {this_sql}" - - return f", {this_sql}" - - if op_sql != "STRAIGHT_JOIN": - op_sql = f"{op_sql} JOIN" if op_sql else "JOIN" - - pivots = self.expressions(expression, key="pivots", sep="", flat=True) - return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}{pivots}" - - def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str: - args = self.expressions(expression, flat=True) - args = f"({args})" if len(args.split(",")) > 1 else args - return f"{args} {arrow_sep} {self.sql(expression, 'this')}" - - def lateral_op(self, expression: exp.Lateral) -> str: - cross_apply = expression.args.get("cross_apply") - - # https://www.mssqltips.com/sqlservertip/1958/sql-server-cross-apply-and-outer-apply/ - if cross_apply is True: - op = "INNER JOIN " - elif cross_apply is False: - op = "LEFT JOIN " - else: - op = "" - - return f"{op}LATERAL" - - def lateral_sql(self, expression: exp.Lateral) -> str: - this = self.sql(expression, "this") - - if expression.args.get("view"): - alias = expression.args["alias"] - columns = self.expressions(alias, key="columns", flat=True) - table = f" {alias.name}" if alias.name else "" - columns = f" AS {columns}" if columns else "" - op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") - return f"{op_sql}{self.sep()}{this}{table}{columns}" - - alias = self.sql(expression, "alias") - alias = f" AS {alias}" if alias else "" - - ordinality = expression.args.get("ordinality") or "" - if ordinality: - ordinality = f" WITH ORDINALITY{alias}" - alias = "" - - return f"{self.lateral_op(expression)} {this}{alias}{ordinality}" - - def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: - this = self.sql(expression, 
"this") - - args = [ - self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e - for e in (expression.args.get(k) for k in ("offset", "expression")) - if e - ] - - args_sql = ", ".join(self.sql(e) for e in args) - args_sql = f"({args_sql})" if top and any(not e.is_number for e in args) else args_sql - expressions = self.expressions(expression, flat=True) - limit_options = self.sql(expression, "limit_options") - expressions = f" BY {expressions}" if expressions else "" - - return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}{limit_options}{expressions}" - - def offset_sql(self, expression: exp.Offset) -> str: - this = self.sql(expression, "this") - value = expression.expression - value = self._simplify_unless_literal(value) if self.LIMIT_ONLY_LITERALS else value - expressions = self.expressions(expression, flat=True) - expressions = f" BY {expressions}" if expressions else "" - return f"{this}{self.seg('OFFSET')} {self.sql(value)}{expressions}" - - def setitem_sql(self, expression: exp.SetItem) -> str: - kind = self.sql(expression, "kind") - kind = f"{kind} " if kind else "" - this = self.sql(expression, "this") - expressions = self.expressions(expression) - collate = self.sql(expression, "collate") - collate = f" COLLATE {collate}" if collate else "" - global_ = "GLOBAL " if expression.args.get("global") else "" - return f"{global_}{kind}{this}{expressions}{collate}" - - def set_sql(self, expression: exp.Set) -> str: - expressions = f" {self.expressions(expression, flat=True)}" - tag = " TAG" if expression.args.get("tag") else "" - return f"{'UNSET' if expression.args.get('unset') else 'SET'}{tag}{expressions}" - - def pragma_sql(self, expression: exp.Pragma) -> str: - return f"PRAGMA {self.sql(expression, 'this')}" - - def lock_sql(self, expression: exp.Lock) -> str: - if not self.LOCKING_READS_SUPPORTED: - self.unsupported("Locking reads using 'FOR UPDATE/SHARE' are not supported") - return "" - - lock_type = "FOR UPDATE" if expression.args["update"] else "FOR SHARE" - expressions = self.expressions(expression, flat=True) - expressions = f" OF {expressions}" if expressions else "" - wait = expression.args.get("wait") - - if wait is not None: - if isinstance(wait, exp.Literal): - wait = f" WAIT {self.sql(wait)}" - else: - wait = " NOWAIT" if wait else " SKIP LOCKED" - - return f"{lock_type}{expressions}{wait or ''}" - - def literal_sql(self, expression: exp.Literal) -> str: - text = expression.this or "" - if expression.is_string: - text = f"{self.dialect.QUOTE_START}{self.escape_str(text)}{self.dialect.QUOTE_END}" - return text - - def escape_str(self, text: str, escape_backslash: bool = True) -> str: - if self.dialect.ESCAPED_SEQUENCES: - to_escaped = self.dialect.ESCAPED_SEQUENCES - text = "".join( - to_escaped.get(ch, ch) if escape_backslash or ch != "\\" else ch for ch in text - ) - - return self._replace_line_breaks(text).replace( - self.dialect.QUOTE_END, self._escaped_quote_end - ) - - def loaddata_sql(self, expression: exp.LoadData) -> str: - local = " LOCAL" if expression.args.get("local") else "" - inpath = f" INPATH {self.sql(expression, 'inpath')}" - overwrite = " OVERWRITE" if expression.args.get("overwrite") else "" - this = f" INTO TABLE {self.sql(expression, 'this')}" - partition = self.sql(expression, "partition") - partition = f" {partition}" if partition else "" - input_format = self.sql(expression, "input_format") - input_format = f" INPUTFORMAT {input_format}" if input_format else "" - serde = self.sql(expression, "serde") - serde = f" SERDE 
{serde}" if serde else "" - return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}" - - def null_sql(self, *_) -> str: - return "NULL" - - def boolean_sql(self, expression: exp.Boolean) -> str: - return "TRUE" if expression.this else "FALSE" - - def order_sql(self, expression: exp.Order, flat: bool = False) -> str: - this = self.sql(expression, "this") - this = f"{this} " if this else this - siblings = "SIBLINGS " if expression.args.get("siblings") else "" - return self.op_expressions(f"{this}ORDER {siblings}BY", expression, flat=this or flat) # type: ignore - - def withfill_sql(self, expression: exp.WithFill) -> str: - from_sql = self.sql(expression, "from") - from_sql = f" FROM {from_sql}" if from_sql else "" - to_sql = self.sql(expression, "to") - to_sql = f" TO {to_sql}" if to_sql else "" - step_sql = self.sql(expression, "step") - step_sql = f" STEP {step_sql}" if step_sql else "" - interpolated_values = [ - f"{self.sql(e, 'alias')} AS {self.sql(e, 'this')}" - if isinstance(e, exp.Alias) - else self.sql(e, "this") - for e in expression.args.get("interpolate") or [] - ] - interpolate = ( - f" INTERPOLATE ({', '.join(interpolated_values)})" if interpolated_values else "" - ) - return f"WITH FILL{from_sql}{to_sql}{step_sql}{interpolate}" - - def cluster_sql(self, expression: exp.Cluster) -> str: - return self.op_expressions("CLUSTER BY", expression) - - def distribute_sql(self, expression: exp.Distribute) -> str: - return self.op_expressions("DISTRIBUTE BY", expression) - - def sort_sql(self, expression: exp.Sort) -> str: - return self.op_expressions("SORT BY", expression) - - def ordered_sql(self, expression: exp.Ordered) -> str: - desc = expression.args.get("desc") - asc = not desc - - nulls_first = expression.args.get("nulls_first") - nulls_last = not nulls_first - nulls_are_large = self.dialect.NULL_ORDERING == "nulls_are_large" - nulls_are_small = self.dialect.NULL_ORDERING == "nulls_are_small" - nulls_are_last = self.dialect.NULL_ORDERING == "nulls_are_last" - - this = self.sql(expression, "this") - - sort_order = " DESC" if desc else (" ASC" if desc is False else "") - nulls_sort_change = "" - if nulls_first and ( - (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last - ): - nulls_sort_change = " NULLS FIRST" - elif ( - nulls_last - and ((asc and nulls_are_small) or (desc and nulls_are_large)) - and not nulls_are_last - ): - nulls_sort_change = " NULLS LAST" - - # If the NULLS FIRST/LAST clause is unsupported, we add another sort key to simulate it - if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: - window = expression.find_ancestor(exp.Window, exp.Select) - if isinstance(window, exp.Window) and window.args.get("spec"): - self.unsupported( - f"'{nulls_sort_change.strip()}' translation not supported in window functions" - ) - nulls_sort_change = "" - elif self.NULL_ORDERING_SUPPORTED is False and ( - (asc and nulls_sort_change == " NULLS LAST") - or (desc and nulls_sort_change == " NULLS FIRST") - ): - # BigQuery does not allow these ordering/nulls combinations when used under - # an aggregation func or under a window containing one - ancestor = expression.find_ancestor(exp.AggFunc, exp.Window, exp.Select) - - if isinstance(ancestor, exp.Window): - ancestor = ancestor.this - if isinstance(ancestor, exp.AggFunc): - self.unsupported( - f"'{nulls_sort_change.strip()}' translation not supported for aggregate functions with {sort_order} sort order" - ) - nulls_sort_change = "" - elif self.NULL_ORDERING_SUPPORTED is None: - if 
expression.this.is_int: - self.unsupported( - f"'{nulls_sort_change.strip()}' translation not supported with positional ordering" - ) - elif not isinstance(expression.this, exp.Rand): - null_sort_order = " DESC" if nulls_sort_change == " NULLS FIRST" else "" - this = f"CASE WHEN {this} IS NULL THEN 1 ELSE 0 END{null_sort_order}, {this}" - nulls_sort_change = "" - - with_fill = self.sql(expression, "with_fill") - with_fill = f" {with_fill}" if with_fill else "" - - return f"{this}{sort_order}{nulls_sort_change}{with_fill}" - - def matchrecognizemeasure_sql(self, expression: exp.MatchRecognizeMeasure) -> str: - window_frame = self.sql(expression, "window_frame") - window_frame = f"{window_frame} " if window_frame else "" - - this = self.sql(expression, "this") - - return f"{window_frame}{this}" - - def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str: - partition = self.partition_by_sql(expression) - order = self.sql(expression, "order") - measures = self.expressions(expression, key="measures") - measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else "" - rows = self.sql(expression, "rows") - rows = self.seg(rows) if rows else "" - after = self.sql(expression, "after") - after = self.seg(after) if after else "" - pattern = self.sql(expression, "pattern") - pattern = self.seg(f"PATTERN ({pattern})") if pattern else "" - definition_sqls = [ - f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}" - for definition in expression.args.get("define", []) - ] - definitions = self.expressions(sqls=definition_sqls) - define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else "" - body = "".join( - ( - partition, - order, - measures, - rows, - after, - pattern, - define, - ) - ) - alias = self.sql(expression, "alias") - alias = f" {alias}" if alias else "" - return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}" - - def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: - limit = expression.args.get("limit") - - if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch): - limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count"))) - elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit): - limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression)) - - return csv( - *sqls, - *[self.sql(join) for join in expression.args.get("joins") or []], - self.sql(expression, "match"), - *[self.sql(lateral) for lateral in expression.args.get("laterals") or []], - self.sql(expression, "prewhere"), - self.sql(expression, "where"), - self.sql(expression, "connect"), - self.sql(expression, "group"), - self.sql(expression, "having"), - *[gen(self, expression) for gen in self.AFTER_HAVING_MODIFIER_TRANSFORMS.values()], - self.sql(expression, "order"), - *self.offset_limit_modifiers(expression, isinstance(limit, exp.Fetch), limit), - *self.after_limit_modifiers(expression), - self.options_modifier(expression), - self.for_modifiers(expression), - sep="", - ) - - def options_modifier(self, expression: exp.Expression) -> str: - options = self.expressions(expression, key="options") - return f" {options}" if options else "" - - def for_modifiers(self, expression: exp.Expression) -> str: - for_modifiers = self.expressions(expression, key="for") - return f"{self.sep()}FOR XML{self.seg(for_modifiers)}" if for_modifiers else "" - - def queryoption_sql(self, expression: exp.QueryOption) -> str: - self.unsupported("Unsupported query option.") - return "" - - def offset_limit_modifiers( - self, expression: 
exp.Expression, fetch: bool, limit: t.Optional[exp.Fetch | exp.Limit]
-    ) -> t.List[str]:
-        return [
-            self.sql(expression, "offset") if fetch else self.sql(limit),
-            self.sql(limit) if fetch else self.sql(expression, "offset"),
-        ]
-
-    def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
-        locks = self.expressions(expression, key="locks", sep=" ")
-        locks = f" {locks}" if locks else ""
-        return [locks, self.sql(expression, "sample")]
-
-    def select_sql(self, expression: exp.Select) -> str:
-        into = expression.args.get("into")
-        if not self.SUPPORTS_SELECT_INTO and into:
-            into.pop()
-
-        hint = self.sql(expression, "hint")
-        distinct = self.sql(expression, "distinct")
-        distinct = f" {distinct}" if distinct else ""
-        kind = self.sql(expression, "kind")
-
-        limit = expression.args.get("limit")
-        if isinstance(limit, exp.Limit) and self.LIMIT_IS_TOP:
-            top = self.limit_sql(limit, top=True)
-            limit.pop()
-        else:
-            top = ""
-
-        expressions = self.expressions(expression)
-
-        if kind:
-            if kind in self.SELECT_KINDS:
-                kind = f" AS {kind}"
-            else:
-                if kind == "STRUCT":
-                    expressions = self.expressions(
-                        sqls=[
-                            self.sql(
-                                exp.Struct(
-                                    expressions=[
-                                        exp.PropertyEQ(this=e.args.get("alias"), expression=e.this)
-                                        if isinstance(e, exp.Alias)
-                                        else e
-                                        for e in expression.expressions
-                                    ]
-                                )
-                            )
-                        ]
-                    )
-                kind = ""
-
-        operation_modifiers = self.expressions(expression, key="operation_modifiers", sep=" ")
-        operation_modifiers = f"{self.sep()}{operation_modifiers}" if operation_modifiers else ""
-
-        # We use LIMIT_IS_TOP as a proxy for whether DISTINCT should go first because tsql and Teradata
-        # are the only dialects that use LIMIT_IS_TOP and both place DISTINCT first.
-        top_distinct = f"{distinct}{hint}{top}" if self.LIMIT_IS_TOP else f"{top}{hint}{distinct}"
-        expressions = f"{self.sep()}{expressions}" if expressions else expressions
-        sql = self.query_modifiers(
-            expression,
-            f"SELECT{top_distinct}{operation_modifiers}{kind}{expressions}",
-            self.sql(expression, "into", comment=False),
-            self.sql(expression, "from", comment=False),
-        )
-
-        # If both the CTE and SELECT clauses have comments, generate the latter earlier
-        if expression.args.get("with"):
-            sql = self.maybe_comment(sql, expression)
-            expression.pop_comments()
-
-        sql = self.prepend_ctes(expression, sql)
-
-        if not self.SUPPORTS_SELECT_INTO and into:
-            if into.args.get("temporary"):
-                table_kind = " TEMPORARY"
-            elif self.SUPPORTS_UNLOGGED_TABLES and into.args.get("unlogged"):
-                table_kind = " UNLOGGED"
-            else:
-                table_kind = ""
-            sql = f"CREATE{table_kind} TABLE {self.sql(into.this)} AS {sql}"
-
-        return sql
-
-    def schema_sql(self, expression: exp.Schema) -> str:
-        this = self.sql(expression, "this")
-        sql = self.schema_columns_sql(expression)
-        return f"{this} {sql}" if this and sql else this or sql
-
-    def schema_columns_sql(self, expression: exp.Schema) -> str:
-        if expression.expressions:
-            return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
-        return ""
-
-    def star_sql(self, expression: exp.Star) -> str:
-        except_ = self.expressions(expression, key="except", flat=True)
-        except_ = f"{self.seg(self.STAR_EXCEPT)} ({except_})" if except_ else ""
-        replace = self.expressions(expression, key="replace", flat=True)
-        replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
-        rename = self.expressions(expression, key="rename", flat=True)
-        rename = f"{self.seg('RENAME')} ({rename})" if rename else ""
-        return f"*{except_}{replace}{rename}"
-
-    def parameter_sql(self,
expression: exp.Parameter) -> str:
-        this = self.sql(expression, "this")
-        return f"{self.PARAMETER_TOKEN}{this}"
-
-    def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
-        this = self.sql(expression, "this")
-        kind = expression.text("kind")
-        if kind:
-            kind = f"{kind}."
-        return f"@@{kind}{this}"
-
-    def placeholder_sql(self, expression: exp.Placeholder) -> str:
-        return f"{self.NAMED_PLACEHOLDER_TOKEN}{expression.name}" if expression.this else "?"
-
-    def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
-        alias = self.sql(expression, "alias")
-        alias = f"{sep}{alias}" if alias else ""
-        sample = self.sql(expression, "sample")
-        if self.dialect.ALIAS_POST_TABLESAMPLE and sample:
-            alias = f"{sample}{alias}"
-
-            # Set to None so it's not generated again by self.query_modifiers()
-            expression.set("sample", None)
-
-        pivots = self.expressions(expression, key="pivots", sep="", flat=True)
-        sql = self.query_modifiers(expression, self.wrap(expression), alias, pivots)
-        return self.prepend_ctes(expression, sql)
-
-    def qualify_sql(self, expression: exp.Qualify) -> str:
-        this = self.indent(self.sql(expression, "this"))
-        return f"{self.seg('QUALIFY')}{self.sep()}{this}"
-
-    def unnest_sql(self, expression: exp.Unnest) -> str:
-        args = self.expressions(expression, flat=True)
-
-        alias = expression.args.get("alias")
-        offset = expression.args.get("offset")
-
-        if self.UNNEST_WITH_ORDINALITY:
-            if alias and isinstance(offset, exp.Expression):
-                alias.append("columns", offset)
-
-        if alias and self.dialect.UNNEST_COLUMN_ONLY:
-            columns = alias.columns
-            alias = self.sql(columns[0]) if columns else ""
-        else:
-            alias = self.sql(alias)
-
-        alias = f" AS {alias}" if alias else alias
-        if self.UNNEST_WITH_ORDINALITY:
-            suffix = f" WITH ORDINALITY{alias}" if offset else alias
-        else:
-            if isinstance(offset, exp.Expression):
-                suffix = f"{alias} WITH OFFSET AS {self.sql(offset)}"
-            elif offset:
-                suffix = f"{alias} WITH OFFSET"
-            else:
-                suffix = alias
-
-        return f"UNNEST({args}){suffix}"
-
-    def prewhere_sql(self, expression: exp.PreWhere) -> str:
-        return ""
-
-    def where_sql(self, expression: exp.Where) -> str:
-        this = self.indent(self.sql(expression, "this"))
-        return f"{self.seg('WHERE')}{self.sep()}{this}"
-
-    def window_sql(self, expression: exp.Window) -> str:
-        this = self.sql(expression, "this")
-        partition = self.partition_by_sql(expression)
-        order = expression.args.get("order")
-        order = self.order_sql(order, flat=True) if order else ""
-        spec = self.sql(expression, "spec")
-        alias = self.sql(expression, "alias")
-        over = self.sql(expression, "over") or "OVER"
-
-        this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"
-
-        first = expression.args.get("first")
-        if first is None:
-            first = ""
-        else:
-            first = "FIRST" if first else "LAST"
-
-        if not partition and not order and not spec and alias:
-            return f"{this} {alias}"
-
-        args = self.format_args(
-            *[arg for arg in (alias, first, partition, order, spec) if arg], sep=" "
-        )
-        return f"{this} ({args})"
-
-    def partition_by_sql(self, expression: exp.Window | exp.MatchRecognize) -> str:
-        partition = self.expressions(expression, key="partition_by", flat=True)
-        return f"PARTITION BY {partition}" if partition else ""
-
-    def windowspec_sql(self, expression: exp.WindowSpec) -> str:
-        kind = self.sql(expression, "kind")
-        start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
-        end = (
-            csv(self.sql(expression, "end"), self.sql(expression, "end_side"),
sep=" ") - or "CURRENT ROW" - ) - - window_spec = f"{kind} BETWEEN {start} AND {end}" - - exclude = self.sql(expression, "exclude") - if exclude: - if self.SUPPORTS_WINDOW_EXCLUDE: - window_spec += f" EXCLUDE {exclude}" - else: - self.unsupported("EXCLUDE clause is not supported in the WINDOW clause") - - return window_spec - - def withingroup_sql(self, expression: exp.WithinGroup) -> str: - this = self.sql(expression, "this") - expression_sql = self.sql(expression, "expression")[1:] # order has a leading space - return f"{this} WITHIN GROUP ({expression_sql})" - - def between_sql(self, expression: exp.Between) -> str: - this = self.sql(expression, "this") - low = self.sql(expression, "low") - high = self.sql(expression, "high") - return f"{this} BETWEEN {low} AND {high}" - - def bracket_offset_expressions( - self, expression: exp.Bracket, index_offset: t.Optional[int] = None - ) -> t.List[exp.Expression]: - return apply_index_offset( - expression.this, - expression.expressions, - (index_offset or self.dialect.INDEX_OFFSET) - expression.args.get("offset", 0), - dialect=self.dialect, - ) - - def bracket_sql(self, expression: exp.Bracket) -> str: - expressions = self.bracket_offset_expressions(expression) - expressions_sql = ", ".join(self.sql(e) for e in expressions) - return f"{self.sql(expression, 'this')}[{expressions_sql}]" - - def all_sql(self, expression: exp.All) -> str: - return f"ALL {self.wrap(expression)}" - - def any_sql(self, expression: exp.Any) -> str: - this = self.sql(expression, "this") - if isinstance(expression.this, (*exp.UNWRAPPED_QUERIES, exp.Paren)): - if isinstance(expression.this, exp.UNWRAPPED_QUERIES): - this = self.wrap(this) - return f"ANY{this}" - return f"ANY {this}" - - def exists_sql(self, expression: exp.Exists) -> str: - return f"EXISTS{self.wrap(expression)}" - - def case_sql(self, expression: exp.Case) -> str: - this = self.sql(expression, "this") - statements = [f"CASE {this}" if this else "CASE"] - - for e in expression.args["ifs"]: - statements.append(f"WHEN {self.sql(e, 'this')}") - statements.append(f"THEN {self.sql(e, 'true')}") - - default = self.sql(expression, "default") - - if default: - statements.append(f"ELSE {default}") - - statements.append("END") - - if self.pretty and self.too_wide(statements): - return self.indent("\n".join(statements), skip_first=True, skip_last=True) - - return " ".join(statements) - - def constraint_sql(self, expression: exp.Constraint) -> str: - this = self.sql(expression, "this") - expressions = self.expressions(expression, flat=True) - return f"CONSTRAINT {this} {expressions}" - - def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str: - order = expression.args.get("order") - order = f" OVER ({self.order_sql(order, flat=True)})" if order else "" - return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}" - - def extract_sql(self, expression: exp.Extract) -> str: - this = self.sql(expression, "this") if self.EXTRACT_ALLOWS_QUOTES else expression.this.name - expression_sql = self.sql(expression, "expression") - return f"EXTRACT({this} FROM {expression_sql})" - - def trim_sql(self, expression: exp.Trim) -> str: - trim_type = self.sql(expression, "position") - - if trim_type == "LEADING": - func_name = "LTRIM" - elif trim_type == "TRAILING": - func_name = "RTRIM" - else: - func_name = "TRIM" - - return self.func(func_name, expression.this, expression.expression) - - def convert_concat_args(self, expression: exp.Concat | exp.ConcatWs) -> t.List[exp.Expression]: - args = expression.expressions - if 
isinstance(expression, exp.ConcatWs): - args = args[1:] # Skip the delimiter - - if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"): - args = [exp.cast(e, exp.DataType.Type.TEXT) for e in args] - - if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"): - args = [exp.func("coalesce", e, exp.Literal.string("")) for e in args] - - return args - - def concat_sql(self, expression: exp.Concat) -> str: - expressions = self.convert_concat_args(expression) - - # Some dialects don't allow a single-argument CONCAT call - if not self.SUPPORTS_SINGLE_ARG_CONCAT and len(expressions) == 1: - return self.sql(expressions[0]) - - return self.func("CONCAT", *expressions) - - def concatws_sql(self, expression: exp.ConcatWs) -> str: - return self.func( - "CONCAT_WS", seq_get(expression.expressions, 0), *self.convert_concat_args(expression) - ) - - def check_sql(self, expression: exp.Check) -> str: - this = self.sql(expression, key="this") - return f"CHECK ({this})" - - def foreignkey_sql(self, expression: exp.ForeignKey) -> str: - expressions = self.expressions(expression, flat=True) - expressions = f" ({expressions})" if expressions else "" - reference = self.sql(expression, "reference") - reference = f" {reference}" if reference else "" - delete = self.sql(expression, "delete") - delete = f" ON DELETE {delete}" if delete else "" - update = self.sql(expression, "update") - update = f" ON UPDATE {update}" if update else "" - options = self.expressions(expression, key="options", flat=True, sep=" ") - options = f" {options}" if options else "" - return f"FOREIGN KEY{expressions}{reference}{delete}{update}{options}" - - def primarykey_sql(self, expression: exp.ForeignKey) -> str: - expressions = self.expressions(expression, flat=True) - options = self.expressions(expression, key="options", flat=True, sep=" ") - options = f" {options}" if options else "" - return f"PRIMARY KEY ({expressions}){options}" - - def if_sql(self, expression: exp.If) -> str: - return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) - - def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: - modifier = expression.args.get("modifier") - modifier = f" {modifier}" if modifier else "" - return f"{self.func('MATCH', *expression.expressions)} AGAINST({self.sql(expression, 'this')}{modifier})" - - def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str: - return f"{self.sql(expression, 'this')}{self.JSON_KEY_VALUE_PAIR_SEP} {self.sql(expression, 'expression')}" - - def jsonpath_sql(self, expression: exp.JSONPath) -> str: - path = self.expressions(expression, sep="", flat=True).lstrip(".") - - if expression.args.get("escape"): - path = self.escape_str(path) - - if self.QUOTE_JSON_PATH: - path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" - - return path - - def json_path_part(self, expression: int | str | exp.JSONPathPart) -> str: - if isinstance(expression, exp.JSONPathPart): - transform = self.TRANSFORMS.get(expression.__class__) - if not callable(transform): - self.unsupported(f"Unsupported JSONPathPart type {expression.__class__.__name__}") - return "" - - return transform(self, expression) - - if isinstance(expression, int): - return str(expression) - - if self._quote_json_path_key_using_brackets and self.JSON_PATH_SINGLE_QUOTE_ESCAPE: - escaped = expression.replace("'", "\\'") - escaped = f"\\'{expression}\\'" - else: - escaped = expression.replace('"', '\\"') - escaped = f'"{escaped}"' - - return escaped - - def formatjson_sql(self, 
expression: exp.FormatJson) -> str: - return f"{self.sql(expression, 'this')} FORMAT JSON" - - def jsonobject_sql(self, expression: exp.JSONObject | exp.JSONObjectAgg) -> str: - null_handling = expression.args.get("null_handling") - null_handling = f" {null_handling}" if null_handling else "" - - unique_keys = expression.args.get("unique_keys") - if unique_keys is not None: - unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS" - else: - unique_keys = "" - - return_type = self.sql(expression, "return_type") - return_type = f" RETURNING {return_type}" if return_type else "" - encoding = self.sql(expression, "encoding") - encoding = f" ENCODING {encoding}" if encoding else "" - - return self.func( - "JSON_OBJECT" if isinstance(expression, exp.JSONObject) else "JSON_OBJECTAGG", - *expression.expressions, - suffix=f"{null_handling}{unique_keys}{return_type}{encoding})", - ) - - def jsonobjectagg_sql(self, expression: exp.JSONObjectAgg) -> str: - return self.jsonobject_sql(expression) - - def jsonarray_sql(self, expression: exp.JSONArray) -> str: - null_handling = expression.args.get("null_handling") - null_handling = f" {null_handling}" if null_handling else "" - return_type = self.sql(expression, "return_type") - return_type = f" RETURNING {return_type}" if return_type else "" - strict = " STRICT" if expression.args.get("strict") else "" - return self.func( - "JSON_ARRAY", *expression.expressions, suffix=f"{null_handling}{return_type}{strict})" - ) - - def jsonarrayagg_sql(self, expression: exp.JSONArrayAgg) -> str: - this = self.sql(expression, "this") - order = self.sql(expression, "order") - null_handling = expression.args.get("null_handling") - null_handling = f" {null_handling}" if null_handling else "" - return_type = self.sql(expression, "return_type") - return_type = f" RETURNING {return_type}" if return_type else "" - strict = " STRICT" if expression.args.get("strict") else "" - return self.func( - "JSON_ARRAYAGG", - this, - suffix=f"{order}{null_handling}{return_type}{strict})", - ) - - def jsoncolumndef_sql(self, expression: exp.JSONColumnDef) -> str: - path = self.sql(expression, "path") - path = f" PATH {path}" if path else "" - nested_schema = self.sql(expression, "nested_schema") - - if nested_schema: - return f"NESTED{path} {nested_schema}" - - this = self.sql(expression, "this") - kind = self.sql(expression, "kind") - kind = f" {kind}" if kind else "" - return f"{this}{kind}{path}" - - def jsonschema_sql(self, expression: exp.JSONSchema) -> str: - return self.func("COLUMNS", *expression.expressions) - - def jsontable_sql(self, expression: exp.JSONTable) -> str: - this = self.sql(expression, "this") - path = self.sql(expression, "path") - path = f", {path}" if path else "" - error_handling = expression.args.get("error_handling") - error_handling = f" {error_handling}" if error_handling else "" - empty_handling = expression.args.get("empty_handling") - empty_handling = f" {empty_handling}" if empty_handling else "" - schema = self.sql(expression, "schema") - return self.func( - "JSON_TABLE", this, suffix=f"{path}{error_handling}{empty_handling} {schema})" - ) - - def openjsoncolumndef_sql(self, expression: exp.OpenJSONColumnDef) -> str: - this = self.sql(expression, "this") - kind = self.sql(expression, "kind") - path = self.sql(expression, "path") - path = f" {path}" if path else "" - as_json = " AS JSON" if expression.args.get("as_json") else "" - return f"{this} {kind}{path}{as_json}" - - def openjson_sql(self, expression: exp.OpenJSON) -> str: - this = 
self.sql(expression, "this") - path = self.sql(expression, "path") - path = f", {path}" if path else "" - expressions = self.expressions(expression) - with_ = ( - f" WITH ({self.seg(self.indent(expressions), sep='')}{self.seg(')', sep='')}" - if expressions - else "" - ) - return f"OPENJSON({this}{path}){with_}" - - def in_sql(self, expression: exp.In) -> str: - query = expression.args.get("query") - unnest = expression.args.get("unnest") - field = expression.args.get("field") - is_global = " GLOBAL" if expression.args.get("is_global") else "" - - if query: - in_sql = self.sql(query) - elif unnest: - in_sql = self.in_unnest_op(unnest) - elif field: - in_sql = self.sql(field) - else: - in_sql = f"({self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)})" - - return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}" - - def in_unnest_op(self, unnest: exp.Unnest) -> str: - return f"(SELECT {self.sql(unnest)})" - - def interval_sql(self, expression: exp.Interval) -> str: - unit = self.sql(expression, "unit") - if not self.INTERVAL_ALLOWS_PLURAL_FORM: - unit = self.TIME_PART_SINGULARS.get(unit, unit) - unit = f" {unit}" if unit else "" - - if self.SINGLE_STRING_INTERVAL: - this = expression.this.name if expression.this else "" - return f"INTERVAL '{this}{unit}'" if this else f"INTERVAL{unit}" - - this = self.sql(expression, "this") - if this: - unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES) - this = f" {this}" if unwrapped else f" ({this})" - - return f"INTERVAL{this}{unit}" - - def return_sql(self, expression: exp.Return) -> str: - return f"RETURN {self.sql(expression, 'this')}" - - def reference_sql(self, expression: exp.Reference) -> str: - this = self.sql(expression, "this") - expressions = self.expressions(expression, flat=True) - expressions = f"({expressions})" if expressions else "" - options = self.expressions(expression, key="options", flat=True, sep=" ") - options = f" {options}" if options else "" - return f"REFERENCES {this}{expressions}{options}" - - def anonymous_sql(self, expression: exp.Anonymous) -> str: - # We don't normalize qualified functions such as a.b.foo(), because they can be case-sensitive - parent = expression.parent - is_qualified = isinstance(parent, exp.Dot) and expression is parent.expression - return self.func( - self.sql(expression, "this"), *expression.expressions, normalize=not is_qualified - ) - - def paren_sql(self, expression: exp.Paren) -> str: - sql = self.seg(self.indent(self.sql(expression, "this")), sep="") - return f"({sql}{self.seg(')', sep='')}" - - def neg_sql(self, expression: exp.Neg) -> str: - # This makes sure we don't convert "- - 5" to "--5", which is a comment - this_sql = self.sql(expression, "this") - sep = " " if this_sql[0] == "-" else "" - return f"-{sep}{this_sql}" - - def not_sql(self, expression: exp.Not) -> str: - return f"NOT {self.sql(expression, 'this')}" - - def alias_sql(self, expression: exp.Alias) -> str: - alias = self.sql(expression, "alias") - alias = f" AS {alias}" if alias else "" - return f"{self.sql(expression, 'this')}{alias}" - - def pivotalias_sql(self, expression: exp.PivotAlias) -> str: - alias = expression.args["alias"] - - parent = expression.parent - pivot = parent and parent.parent - - if isinstance(pivot, exp.Pivot) and pivot.unpivot: - identifier_alias = isinstance(alias, exp.Identifier) - literal_alias = isinstance(alias, exp.Literal) - - if identifier_alias and not self.UNPIVOT_ALIASES_ARE_IDENTIFIERS: - 
alias.replace(exp.Literal.string(alias.output_name)) - elif not identifier_alias and literal_alias and self.UNPIVOT_ALIASES_ARE_IDENTIFIERS: - alias.replace(exp.to_identifier(alias.output_name)) - - return self.alias_sql(expression) - - def aliases_sql(self, expression: exp.Aliases) -> str: - return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" - - def atindex_sql(self, expression: exp.AtTimeZone) -> str: - this = self.sql(expression, "this") - index = self.sql(expression, "expression") - return f"{this} AT {index}" - - def attimezone_sql(self, expression: exp.AtTimeZone) -> str: - this = self.sql(expression, "this") - zone = self.sql(expression, "zone") - return f"{this} AT TIME ZONE {zone}" - - def fromtimezone_sql(self, expression: exp.FromTimeZone) -> str: - this = self.sql(expression, "this") - zone = self.sql(expression, "zone") - return f"{this} AT TIME ZONE {zone} AT TIME ZONE 'UTC'" - - def add_sql(self, expression: exp.Add) -> str: - return self.binary(expression, "+") - - def and_sql( - self, expression: exp.And, stack: t.Optional[t.List[str | exp.Expression]] = None - ) -> str: - return self.connector_sql(expression, "AND", stack) - - def or_sql( - self, expression: exp.Or, stack: t.Optional[t.List[str | exp.Expression]] = None - ) -> str: - return self.connector_sql(expression, "OR", stack) - - def xor_sql( - self, expression: exp.Xor, stack: t.Optional[t.List[str | exp.Expression]] = None - ) -> str: - return self.connector_sql(expression, "XOR", stack) - - def connector_sql( - self, - expression: exp.Connector, - op: str, - stack: t.Optional[t.List[str | exp.Expression]] = None, - ) -> str: - if stack is not None: - if expression.expressions: - stack.append(self.expressions(expression, sep=f" {op} ")) - else: - stack.append(expression.right) - if expression.comments and self.comments: - for comment in expression.comments: - if comment: - op += f" /*{self.sanitize_comment(comment)}*/" - stack.extend((op, expression.left)) - return op - - stack = [expression] - sqls: t.List[str] = [] - ops = set() - - while stack: - node = stack.pop() - if isinstance(node, exp.Connector): - ops.add(getattr(self, f"{node.key}_sql")(node, stack)) - else: - sql = self.sql(node) - if sqls and sqls[-1] in ops: - sqls[-1] += f" {sql}" - else: - sqls.append(sql) - - sep = "\n" if self.pretty and self.too_wide(sqls) else " " - return sep.join(sqls) - - def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str: - return self.binary(expression, "&") - - def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str: - return self.binary(expression, "<<") - - def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str: - return f"~{self.sql(expression, 'this')}" - - def bitwiseor_sql(self, expression: exp.BitwiseOr) -> str: - return self.binary(expression, "|") - - def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str: - return self.binary(expression, ">>") - - def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str: - return self.binary(expression, "^") - - def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str: - format_sql = self.sql(expression, "format") - format_sql = f" FORMAT {format_sql}" if format_sql else "" - to_sql = self.sql(expression, "to") - to_sql = f" {to_sql}" if to_sql else "" - action = self.sql(expression, "action") - action = f" {action}" if action else "" - default = self.sql(expression, "default") - default = f" DEFAULT {default} ON CONVERSION ERROR" if default else "" - return 
f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{default}{format_sql}{action})" - - def currentdate_sql(self, expression: exp.CurrentDate) -> str: - zone = self.sql(expression, "this") - return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" - - def collate_sql(self, expression: exp.Collate) -> str: - if self.COLLATE_IS_FUNC: - return self.function_fallback_sql(expression) - return self.binary(expression, "COLLATE") - - def command_sql(self, expression: exp.Command) -> str: - return f"{self.sql(expression, 'this')} {expression.text('expression').strip()}" - - def comment_sql(self, expression: exp.Comment) -> str: - this = self.sql(expression, "this") - kind = expression.args["kind"] - materialized = " MATERIALIZED" if expression.args.get("materialized") else "" - exists_sql = " IF EXISTS " if expression.args.get("exists") else " " - expression_sql = self.sql(expression, "expression") - return f"COMMENT{exists_sql}ON{materialized} {kind} {this} IS {expression_sql}" - - def mergetreettlaction_sql(self, expression: exp.MergeTreeTTLAction) -> str: - this = self.sql(expression, "this") - delete = " DELETE" if expression.args.get("delete") else "" - recompress = self.sql(expression, "recompress") - recompress = f" RECOMPRESS {recompress}" if recompress else "" - to_disk = self.sql(expression, "to_disk") - to_disk = f" TO DISK {to_disk}" if to_disk else "" - to_volume = self.sql(expression, "to_volume") - to_volume = f" TO VOLUME {to_volume}" if to_volume else "" - return f"{this}{delete}{recompress}{to_disk}{to_volume}" - - def mergetreettl_sql(self, expression: exp.MergeTreeTTL) -> str: - where = self.sql(expression, "where") - group = self.sql(expression, "group") - aggregates = self.expressions(expression, key="aggregates") - aggregates = self.seg("SET") + self.seg(aggregates) if aggregates else "" - - if not (where or group or aggregates) and len(expression.expressions) == 1: - return f"TTL {self.expressions(expression, flat=True)}" - - return f"TTL{self.seg(self.expressions(expression))}{where}{group}{aggregates}" - - def transaction_sql(self, expression: exp.Transaction) -> str: - return "BEGIN" - - def commit_sql(self, expression: exp.Commit) -> str: - chain = expression.args.get("chain") - if chain is not None: - chain = " AND CHAIN" if chain else " AND NO CHAIN" - - return f"COMMIT{chain or ''}" - - def rollback_sql(self, expression: exp.Rollback) -> str: - savepoint = expression.args.get("savepoint") - savepoint = f" TO {savepoint}" if savepoint else "" - return f"ROLLBACK{savepoint}" - - def altercolumn_sql(self, expression: exp.AlterColumn) -> str: - this = self.sql(expression, "this") - - dtype = self.sql(expression, "dtype") - if dtype: - collate = self.sql(expression, "collate") - collate = f" COLLATE {collate}" if collate else "" - using = self.sql(expression, "using") - using = f" USING {using}" if using else "" - alter_set_type = self.ALTER_SET_TYPE + " " if self.ALTER_SET_TYPE else "" - return f"ALTER COLUMN {this} {alter_set_type}{dtype}{collate}{using}" - - default = self.sql(expression, "default") - if default: - return f"ALTER COLUMN {this} SET DEFAULT {default}" - - comment = self.sql(expression, "comment") - if comment: - return f"ALTER COLUMN {this} COMMENT {comment}" - - visible = expression.args.get("visible") - if visible: - return f"ALTER COLUMN {this} SET {visible}" - - allow_null = expression.args.get("allow_null") - drop = expression.args.get("drop") - - if not drop and not allow_null: - self.unsupported("Unsupported ALTER COLUMN syntax") - 
- if allow_null is not None: - keyword = "DROP" if drop else "SET" - return f"ALTER COLUMN {this} {keyword} NOT NULL" - - return f"ALTER COLUMN {this} DROP DEFAULT" - - def alterindex_sql(self, expression: exp.AlterIndex) -> str: - this = self.sql(expression, "this") - - visible = expression.args.get("visible") - visible_sql = "VISIBLE" if visible else "INVISIBLE" - - return f"ALTER INDEX {this} {visible_sql}" - - def alterdiststyle_sql(self, expression: exp.AlterDistStyle) -> str: - this = self.sql(expression, "this") - if not isinstance(expression.this, exp.Var): - this = f"KEY DISTKEY {this}" - return f"ALTER DISTSTYLE {this}" - - def altersortkey_sql(self, expression: exp.AlterSortKey) -> str: - compound = " COMPOUND" if expression.args.get("compound") else "" - this = self.sql(expression, "this") - expressions = self.expressions(expression, flat=True) - expressions = f"({expressions})" if expressions else "" - return f"ALTER{compound} SORTKEY {this or expressions}" - - def alterrename_sql(self, expression: exp.AlterRename) -> str: - if not self.RENAME_TABLE_WITH_DB: - # Remove db from tables - expression = expression.transform( - lambda n: exp.table_(n.this) if isinstance(n, exp.Table) else n - ).assert_is(exp.AlterRename) - this = self.sql(expression, "this") - return f"RENAME TO {this}" - - def renamecolumn_sql(self, expression: exp.RenameColumn) -> str: - exists = " IF EXISTS" if expression.args.get("exists") else "" - old_column = self.sql(expression, "this") - new_column = self.sql(expression, "to") - return f"RENAME COLUMN{exists} {old_column} TO {new_column}" - - def alterset_sql(self, expression: exp.AlterSet) -> str: - exprs = self.expressions(expression, flat=True) - if self.ALTER_SET_WRAPPED: - exprs = f"({exprs})" - - return f"SET {exprs}" - - def alter_sql(self, expression: exp.Alter) -> str: - actions = expression.args["actions"] - - if not self.dialect.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN and isinstance( - actions[0], exp.ColumnDef - ): - actions_sql = self.expressions(expression, key="actions", flat=True) - actions_sql = f"ADD {actions_sql}" - else: - actions_list = [] - for action in actions: - if isinstance(action, (exp.ColumnDef, exp.Schema)): - action_sql = self.add_column_sql(action) - else: - action_sql = self.sql(action) - if isinstance(action, exp.Query): - action_sql = f"AS {action_sql}" - - actions_list.append(action_sql) - - actions_sql = self.format_args(*actions_list) - - exists = " IF EXISTS" if expression.args.get("exists") else "" - on_cluster = self.sql(expression, "cluster") - on_cluster = f" {on_cluster}" if on_cluster else "" - only = " ONLY" if expression.args.get("only") else "" - options = self.expressions(expression, key="options") - options = f", {options}" if options else "" - kind = self.sql(expression, "kind") - not_valid = " NOT VALID" if expression.args.get("not_valid") else "" - - return f"ALTER {kind}{exists}{only} {self.sql(expression, 'this')}{on_cluster} {actions_sql}{not_valid}{options}" - - def add_column_sql(self, expression: exp.Expression) -> str: - sql = self.sql(expression) - if isinstance(expression, exp.Schema): - column_text = " COLUMNS" - elif isinstance(expression, exp.ColumnDef) and self.ALTER_TABLE_INCLUDE_COLUMN_KEYWORD: - column_text = " COLUMN" - else: - column_text = "" - - return f"ADD{column_text} {sql}" - - def droppartition_sql(self, expression: exp.DropPartition) -> str: - expressions = self.expressions(expression) - exists = " IF EXISTS " if expression.args.get("exists") else " " - return 
f"DROP{exists}{expressions}" - - def addconstraint_sql(self, expression: exp.AddConstraint) -> str: - return f"ADD {self.expressions(expression)}" - - def addpartition_sql(self, expression: exp.AddPartition) -> str: - exists = "IF NOT EXISTS " if expression.args.get("exists") else "" - return f"ADD {exists}{self.sql(expression.this)}" - - def distinct_sql(self, expression: exp.Distinct) -> str: - this = self.expressions(expression, flat=True) - - if not self.MULTI_ARG_DISTINCT and len(expression.expressions) > 1: - case = exp.case() - for arg in expression.expressions: - case = case.when(arg.is_(exp.null()), exp.null()) - this = self.sql(case.else_(f"({this})")) - - this = f" {this}" if this else "" - - on = self.sql(expression, "on") - on = f" ON {on}" if on else "" - return f"DISTINCT{this}{on}" - - def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str: - return self._embed_ignore_nulls(expression, "IGNORE NULLS") - - def respectnulls_sql(self, expression: exp.RespectNulls) -> str: - return self._embed_ignore_nulls(expression, "RESPECT NULLS") - - def havingmax_sql(self, expression: exp.HavingMax) -> str: - this_sql = self.sql(expression, "this") - expression_sql = self.sql(expression, "expression") - kind = "MAX" if expression.args.get("max") else "MIN" - return f"{this_sql} HAVING {kind} {expression_sql}" - - def intdiv_sql(self, expression: exp.IntDiv) -> str: - return self.sql( - exp.Cast( - this=exp.Div(this=expression.this, expression=expression.expression), - to=exp.DataType(this=exp.DataType.Type.INT), - ) - ) - - def dpipe_sql(self, expression: exp.DPipe) -> str: - if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"): - return self.func( - "CONCAT", *(exp.cast(e, exp.DataType.Type.TEXT) for e in expression.flatten()) - ) - return self.binary(expression, "||") - - def div_sql(self, expression: exp.Div) -> str: - l, r = expression.left, expression.right - - if not self.dialect.SAFE_DIVISION and expression.args.get("safe"): - r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0))) - - if self.dialect.TYPED_DIVISION and not expression.args.get("typed"): - if not l.is_type(*exp.DataType.REAL_TYPES) and not r.is_type(*exp.DataType.REAL_TYPES): - l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE)) - - elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"): - if l.is_type(*exp.DataType.INTEGER_TYPES) and r.is_type(*exp.DataType.INTEGER_TYPES): - return self.sql( - exp.cast( - l / r, - to=exp.DataType.Type.BIGINT, - ) - ) - - return self.binary(expression, "/") - - def safedivide_sql(self, expression: exp.SafeDivide) -> str: - n = exp._wrap(expression.this, exp.Binary) - d = exp._wrap(expression.expression, exp.Binary) - return self.sql(exp.If(this=d.neq(0), true=n / d, false=exp.Null())) - - def overlaps_sql(self, expression: exp.Overlaps) -> str: - return self.binary(expression, "OVERLAPS") - - def distance_sql(self, expression: exp.Distance) -> str: - return self.binary(expression, "<->") - - def dot_sql(self, expression: exp.Dot) -> str: - return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" - - def eq_sql(self, expression: exp.EQ) -> str: - return self.binary(expression, "=") - - def propertyeq_sql(self, expression: exp.PropertyEQ) -> str: - return self.binary(expression, ":=") - - def escape_sql(self, expression: exp.Escape) -> str: - return self.binary(expression, "ESCAPE") - - def glob_sql(self, expression: exp.Glob) -> str: - return self.binary(expression, "GLOB") - - def gt_sql(self, 
expression: exp.GT) -> str: - return self.binary(expression, ">") - - def gte_sql(self, expression: exp.GTE) -> str: - return self.binary(expression, ">=") - - def ilike_sql(self, expression: exp.ILike) -> str: - return self.binary(expression, "ILIKE") - - def ilikeany_sql(self, expression: exp.ILikeAny) -> str: - return self.binary(expression, "ILIKE ANY") - - def is_sql(self, expression: exp.Is) -> str: - if not self.IS_BOOL_ALLOWED and isinstance(expression.expression, exp.Boolean): - return self.sql( - expression.this if expression.expression.this else exp.not_(expression.this) - ) - return self.binary(expression, "IS") - - def like_sql(self, expression: exp.Like) -> str: - return self.binary(expression, "LIKE") - - def likeany_sql(self, expression: exp.LikeAny) -> str: - return self.binary(expression, "LIKE ANY") - - def similarto_sql(self, expression: exp.SimilarTo) -> str: - return self.binary(expression, "SIMILAR TO") - - def lt_sql(self, expression: exp.LT) -> str: - return self.binary(expression, "<") - - def lte_sql(self, expression: exp.LTE) -> str: - return self.binary(expression, "<=") - - def mod_sql(self, expression: exp.Mod) -> str: - return self.binary(expression, "%") - - def mul_sql(self, expression: exp.Mul) -> str: - return self.binary(expression, "*") - - def neq_sql(self, expression: exp.NEQ) -> str: - return self.binary(expression, "<>") - - def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str: - return self.binary(expression, "IS NOT DISTINCT FROM") - - def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str: - return self.binary(expression, "IS DISTINCT FROM") - - def slice_sql(self, expression: exp.Slice) -> str: - return self.binary(expression, ":") - - def sub_sql(self, expression: exp.Sub) -> str: - return self.binary(expression, "-") - - def trycast_sql(self, expression: exp.TryCast) -> str: - return self.cast_sql(expression, safe_prefix="TRY_") - - def jsoncast_sql(self, expression: exp.JSONCast) -> str: - return self.cast_sql(expression) - - def try_sql(self, expression: exp.Try) -> str: - if not self.TRY_SUPPORTED: - self.unsupported("Unsupported TRY function") - return self.sql(expression, "this") - - return self.func("TRY", expression.this) - - def log_sql(self, expression: exp.Log) -> str: - this = expression.this - expr = expression.expression - - if self.dialect.LOG_BASE_FIRST is False: - this, expr = expr, this - elif self.dialect.LOG_BASE_FIRST is None and expr: - if this.name in ("2", "10"): - return self.func(f"LOG{this.name}", expr) - - self.unsupported(f"Unsupported logarithm with base {self.sql(this)}") - - return self.func("LOG", this, expr) - - def use_sql(self, expression: exp.Use) -> str: - kind = self.sql(expression, "kind") - kind = f" {kind}" if kind else "" - this = self.sql(expression, "this") or self.expressions(expression, flat=True) - this = f" {this}" if this else "" - return f"USE{kind}{this}" - - def binary(self, expression: exp.Binary, op: str) -> str: - sqls: t.List[str] = [] - stack: t.List[t.Union[str, exp.Expression]] = [expression] - binary_type = type(expression) - - while stack: - node = stack.pop() - - if type(node) is binary_type: - op_func = node.args.get("operator") - if op_func: - op = f"OPERATOR({self.sql(op_func)})" - - stack.append(node.right) - stack.append(f" {self.maybe_comment(op, comments=node.comments)} ") - stack.append(node.left) - else: - sqls.append(self.sql(node)) - - return "".join(sqls) - - def ceil_floor(self, expression: exp.Ceil | exp.Floor) -> str: - to_clause = self.sql(expression, 
"to") - if to_clause: - return f"{expression.sql_name()}({self.sql(expression, 'this')} TO {to_clause})" - - return self.function_fallback_sql(expression) - - def function_fallback_sql(self, expression: exp.Func) -> str: - args = [] - - for key in expression.arg_types: - arg_value = expression.args.get(key) - - if isinstance(arg_value, list): - for value in arg_value: - args.append(value) - elif arg_value is not None: - args.append(arg_value) - - if self.dialect.PRESERVE_ORIGINAL_NAMES: - name = (expression._meta and expression.meta.get("name")) or expression.sql_name() - else: - name = expression.sql_name() - - return self.func(name, *args) - - def func( - self, - name: str, - *args: t.Optional[exp.Expression | str], - prefix: str = "(", - suffix: str = ")", - normalize: bool = True, - ) -> str: - name = self.normalize_func(name) if normalize else name - return f"{name}{prefix}{self.format_args(*args)}{suffix}" - - def format_args(self, *args: t.Optional[str | exp.Expression], sep: str = ", ") -> str: - arg_sqls = tuple( - self.sql(arg) for arg in args if arg is not None and not isinstance(arg, bool) - ) - if self.pretty and self.too_wide(arg_sqls): - return self.indent( - "\n" + f"{sep.strip()}\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True - ) - return sep.join(arg_sqls) - - def too_wide(self, args: t.Iterable) -> bool: - return sum(len(arg) for arg in args) > self.max_text_width - - def format_time( - self, - expression: exp.Expression, - inverse_time_mapping: t.Optional[t.Dict[str, str]] = None, - inverse_time_trie: t.Optional[t.Dict] = None, - ) -> t.Optional[str]: - return format_time( - self.sql(expression, "format"), - inverse_time_mapping or self.dialect.INVERSE_TIME_MAPPING, - inverse_time_trie or self.dialect.INVERSE_TIME_TRIE, - ) - - def expressions( - self, - expression: t.Optional[exp.Expression] = None, - key: t.Optional[str] = None, - sqls: t.Optional[t.Collection[str | exp.Expression]] = None, - flat: bool = False, - indent: bool = True, - skip_first: bool = False, - skip_last: bool = False, - sep: str = ", ", - prefix: str = "", - dynamic: bool = False, - new_line: bool = False, - ) -> str: - expressions = expression.args.get(key or "expressions") if expression else sqls - - if not expressions: - return "" - - if flat: - return sep.join(sql for sql in (self.sql(e) for e in expressions) if sql) - - num_sqls = len(expressions) - result_sqls = [] - - for i, e in enumerate(expressions): - sql = self.sql(e, comment=False) - if not sql: - continue - - comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else "" - - if self.pretty: - if self.leading_comma: - result_sqls.append(f"{sep if i > 0 else ''}{prefix}{sql}{comments}") - else: - result_sqls.append( - f"{prefix}{sql}{(sep.rstrip() if comments else sep) if i + 1 < num_sqls else ''}{comments}" - ) - else: - result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}") - - if self.pretty and (not dynamic or self.too_wide(result_sqls)): - if new_line: - result_sqls.insert(0, "") - result_sqls.append("") - result_sql = "\n".join(s.rstrip() for s in result_sqls) - else: - result_sql = "".join(result_sqls) - - return ( - self.indent(result_sql, skip_first=skip_first, skip_last=skip_last) - if indent - else result_sql - ) - - def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str: - flat = flat or isinstance(expression.parent, exp.Properties) - expressions_sql = self.expressions(expression, flat=flat) - if flat: - return f"{op} {expressions_sql}" - 
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" - - def naked_property(self, expression: exp.Property) -> str: - property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__) - if not property_name: - self.unsupported(f"Unsupported property {expression.__class__.__name__}") - return f"{property_name} {self.sql(expression, 'this')}" - - def tag_sql(self, expression: exp.Tag) -> str: - return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}" - - def token_sql(self, token_type: TokenType) -> str: - return self.TOKEN_MAPPING.get(token_type, token_type.name) - - def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: - this = self.sql(expression, "this") - expressions = self.no_identify(self.expressions, expression) - expressions = ( - self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}" - ) - return f"{this}{expressions}" if expressions.strip() != "" else this - - def joinhint_sql(self, expression: exp.JoinHint) -> str: - this = self.sql(expression, "this") - expressions = self.expressions(expression, flat=True) - return f"{this}({expressions})" - - def kwarg_sql(self, expression: exp.Kwarg) -> str: - return self.binary(expression, "=>") - - def when_sql(self, expression: exp.When) -> str: - matched = "MATCHED" if expression.args["matched"] else "NOT MATCHED" - source = " BY SOURCE" if self.MATCHED_BY_SOURCE and expression.args.get("source") else "" - condition = self.sql(expression, "condition") - condition = f" AND {condition}" if condition else "" - - then_expression = expression.args.get("then") - if isinstance(then_expression, exp.Insert): - this = self.sql(then_expression, "this") - this = f"INSERT {this}" if this else "INSERT" - then = self.sql(then_expression, "expression") - then = f"{this} VALUES {then}" if then else this - elif isinstance(then_expression, exp.Update): - if isinstance(then_expression.args.get("expressions"), exp.Star): - then = f"UPDATE {self.sql(then_expression, 'expressions')}" - else: - then = f"UPDATE SET{self.sep()}{self.expressions(then_expression)}" - else: - then = self.sql(then_expression) - return f"WHEN {matched}{source}{condition} THEN {then}" - - def whens_sql(self, expression: exp.Whens) -> str: - return self.expressions(expression, sep=" ", indent=False) - - def merge_sql(self, expression: exp.Merge) -> str: - table = expression.this - table_alias = "" - - hints = table.args.get("hints") - if hints and table.alias and isinstance(hints[0], exp.WithTableHint): - # T-SQL syntax is MERGE ... 
[WITH ()] [[AS] table_alias] - table_alias = f" AS {self.sql(table.args['alias'].pop())}" - - this = self.sql(table) - using = f"USING {self.sql(expression, 'using')}" - on = f"ON {self.sql(expression, 'on')}" - whens = self.sql(expression, "whens") - - returning = self.sql(expression, "returning") - if returning: - whens = f"{whens}{returning}" - - sep = self.sep() - - return self.prepend_ctes( - expression, - f"MERGE INTO {this}{table_alias}{sep}{using}{sep}{on}{sep}{whens}", - ) - - @unsupported_args("format") - def tochar_sql(self, expression: exp.ToChar) -> str: - return self.sql(exp.cast(expression.this, exp.DataType.Type.TEXT)) - - def tonumber_sql(self, expression: exp.ToNumber) -> str: - if not self.SUPPORTS_TO_NUMBER: - self.unsupported("Unsupported TO_NUMBER function") - return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE)) - - fmt = expression.args.get("format") - if not fmt: - self.unsupported("Conversion format is required for TO_NUMBER") - return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE)) - - return self.func("TO_NUMBER", expression.this, fmt) - - def dictproperty_sql(self, expression: exp.DictProperty) -> str: - this = self.sql(expression, "this") - kind = self.sql(expression, "kind") - settings_sql = self.expressions(expression, key="settings", sep=" ") - args = f"({self.sep('')}{settings_sql}{self.seg(')', sep='')}" if settings_sql else "()" - return f"{this}({kind}{args})" - - def dictrange_sql(self, expression: exp.DictRange) -> str: - this = self.sql(expression, "this") - max = self.sql(expression, "max") - min = self.sql(expression, "min") - return f"{this}(MIN {min} MAX {max})" - - def dictsubproperty_sql(self, expression: exp.DictSubProperty) -> str: - return f"{self.sql(expression, 'this')} {self.sql(expression, 'value')}" - - def duplicatekeyproperty_sql(self, expression: exp.DuplicateKeyProperty) -> str: - return f"DUPLICATE KEY ({self.expressions(expression, flat=True)})" - - # https://docs.starrocks.io/docs/sql-reference/sql-statements/table_bucket_part_index/CREATE_TABLE/ - def uniquekeyproperty_sql(self, expression: exp.UniqueKeyProperty) -> str: - return f"UNIQUE KEY ({self.expressions(expression, flat=True)})" - - # https://docs.starrocks.io/docs/sql-reference/sql-statements/data-definition/CREATE_TABLE/#distribution_desc - def distributedbyproperty_sql(self, expression: exp.DistributedByProperty) -> str: - expressions = self.expressions(expression, flat=True) - expressions = f" {self.wrap(expressions)}" if expressions else "" - buckets = self.sql(expression, "buckets") - kind = self.sql(expression, "kind") - buckets = f" BUCKETS {buckets}" if buckets else "" - order = self.sql(expression, "order") - return f"DISTRIBUTED BY {kind}{expressions}{buckets}{order}" - - def oncluster_sql(self, expression: exp.OnCluster) -> str: - return "" - - def clusteredbyproperty_sql(self, expression: exp.ClusteredByProperty) -> str: - expressions = self.expressions(expression, key="expressions", flat=True) - sorted_by = self.expressions(expression, key="sorted_by", flat=True) - sorted_by = f" SORTED BY ({sorted_by})" if sorted_by else "" - buckets = self.sql(expression, "buckets") - return f"CLUSTERED BY ({expressions}){sorted_by} INTO {buckets} BUCKETS" - - def anyvalue_sql(self, expression: exp.AnyValue) -> str: - this = self.sql(expression, "this") - having = self.sql(expression, "having") - - if having: - this = f"{this} HAVING {'MAX' if expression.args.get('max') else 'MIN'} {having}" - - return self.func("ANY_VALUE", this) - - def 
querytransform_sql(self, expression: exp.QueryTransform) -> str: - transform = self.func("TRANSFORM", *expression.expressions) - row_format_before = self.sql(expression, "row_format_before") - row_format_before = f" {row_format_before}" if row_format_before else "" - record_writer = self.sql(expression, "record_writer") - record_writer = f" RECORDWRITER {record_writer}" if record_writer else "" - using = f" USING {self.sql(expression, 'command_script')}" - schema = self.sql(expression, "schema") - schema = f" AS {schema}" if schema else "" - row_format_after = self.sql(expression, "row_format_after") - row_format_after = f" {row_format_after}" if row_format_after else "" - record_reader = self.sql(expression, "record_reader") - record_reader = f" RECORDREADER {record_reader}" if record_reader else "" - return f"{transform}{row_format_before}{record_writer}{using}{schema}{row_format_after}{record_reader}" - - def indexconstraintoption_sql(self, expression: exp.IndexConstraintOption) -> str: - key_block_size = self.sql(expression, "key_block_size") - if key_block_size: - return f"KEY_BLOCK_SIZE = {key_block_size}" - - using = self.sql(expression, "using") - if using: - return f"USING {using}" - - parser = self.sql(expression, "parser") - if parser: - return f"WITH PARSER {parser}" - - comment = self.sql(expression, "comment") - if comment: - return f"COMMENT {comment}" - - visible = expression.args.get("visible") - if visible is not None: - return "VISIBLE" if visible else "INVISIBLE" - - engine_attr = self.sql(expression, "engine_attr") - if engine_attr: - return f"ENGINE_ATTRIBUTE = {engine_attr}" - - secondary_engine_attr = self.sql(expression, "secondary_engine_attr") - if secondary_engine_attr: - return f"SECONDARY_ENGINE_ATTRIBUTE = {secondary_engine_attr}" - - self.unsupported("Unsupported index constraint option.") - return "" - - def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str: - enforced = " ENFORCED" if expression.args.get("enforced") else "" - return f"CHECK ({self.sql(expression, 'this')}){enforced}" - - def indexcolumnconstraint_sql(self, expression: exp.IndexColumnConstraint) -> str: - kind = self.sql(expression, "kind") - kind = f"{kind} INDEX" if kind else "INDEX" - this = self.sql(expression, "this") - this = f" {this}" if this else "" - index_type = self.sql(expression, "index_type") - index_type = f" USING {index_type}" if index_type else "" - expressions = self.expressions(expression, flat=True) - expressions = f" ({expressions})" if expressions else "" - options = self.expressions(expression, key="options", sep=" ") - options = f" {options}" if options else "" - return f"{kind}{this}{index_type}{expressions}{options}" - - def nvl2_sql(self, expression: exp.Nvl2) -> str: - if self.NVL2_SUPPORTED: - return self.function_fallback_sql(expression) - - case = exp.Case().when( - expression.this.is_(exp.null()).not_(copy=False), - expression.args["true"], - copy=False, - ) - else_cond = expression.args.get("false") - if else_cond: - case.else_(else_cond, copy=False) - - return self.sql(case) - - def comprehension_sql(self, expression: exp.Comprehension) -> str: - this = self.sql(expression, "this") - expr = self.sql(expression, "expression") - iterator = self.sql(expression, "iterator") - condition = self.sql(expression, "condition") - condition = f" IF {condition}" if condition else "" - return f"{this} FOR {expr} IN {iterator}{condition}" - - def columnprefix_sql(self, expression: exp.ColumnPrefix) -> str: - return f"{self.sql(expression, 
'this')}({self.sql(expression, 'expression')})" - - def opclass_sql(self, expression: exp.Opclass) -> str: - return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - - def predict_sql(self, expression: exp.Predict) -> str: - model = self.sql(expression, "this") - model = f"MODEL {model}" - table = self.sql(expression, "expression") - table = f"TABLE {table}" if not isinstance(expression.expression, exp.Subquery) else table - parameters = self.sql(expression, "params_struct") - return self.func("PREDICT", model, table, parameters or None) - - def forin_sql(self, expression: exp.ForIn) -> str: - this = self.sql(expression, "this") - expression_sql = self.sql(expression, "expression") - return f"FOR {this} DO {expression_sql}" - - def refresh_sql(self, expression: exp.Refresh) -> str: - this = self.sql(expression, "this") - table = "" if isinstance(expression.this, exp.Literal) else "TABLE " - return f"REFRESH {table}{this}" - - def toarray_sql(self, expression: exp.ToArray) -> str: - arg = expression.this - if not arg.type: - from sqlglot.optimizer.annotate_types import annotate_types - - arg = annotate_types(arg, dialect=self.dialect) - - if arg.is_type(exp.DataType.Type.ARRAY): - return self.sql(arg) - - cond_for_null = arg.is_(exp.null()) - return self.sql(exp.func("IF", cond_for_null, exp.null(), exp.array(arg, copy=False))) - - def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str: - this = expression.this - time_format = self.format_time(expression) - - if time_format: - return self.sql( - exp.cast( - exp.StrToTime(this=this, format=expression.args["format"]), - exp.DataType.Type.TIME, - ) - ) - - if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME): - return self.sql(this) - - return self.sql(exp.cast(this, exp.DataType.Type.TIME)) - - def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str: - this = expression.this - if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type(exp.DataType.Type.TIMESTAMP): - return self.sql(this) - - return self.sql(exp.cast(this, exp.DataType.Type.TIMESTAMP, dialect=self.dialect)) - - def tsordstodatetime_sql(self, expression: exp.TsOrDsToDatetime) -> str: - this = expression.this - if isinstance(this, exp.TsOrDsToDatetime) or this.is_type(exp.DataType.Type.DATETIME): - return self.sql(this) - - return self.sql(exp.cast(this, exp.DataType.Type.DATETIME, dialect=self.dialect)) - - def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str: - this = expression.this - time_format = self.format_time(expression) - - if time_format and time_format not in (self.dialect.TIME_FORMAT, self.dialect.DATE_FORMAT): - return self.sql( - exp.cast( - exp.StrToTime(this=this, format=expression.args["format"]), - exp.DataType.Type.DATE, - ) - ) - - if isinstance(this, exp.TsOrDsToDate) or this.is_type(exp.DataType.Type.DATE): - return self.sql(this) - - return self.sql(exp.cast(this, exp.DataType.Type.DATE)) - - def unixdate_sql(self, expression: exp.UnixDate) -> str: - return self.sql( - exp.func( - "DATEDIFF", - expression.this, - exp.cast(exp.Literal.string("1970-01-01"), exp.DataType.Type.DATE), - "day", - ) - ) - - def lastday_sql(self, expression: exp.LastDay) -> str: - if self.LAST_DAY_SUPPORTS_DATE_PART: - return self.function_fallback_sql(expression) - - unit = expression.text("unit") - if unit and unit != "MONTH": - self.unsupported("Date parts are not supported in LAST_DAY.") - - return self.func("LAST_DAY", expression.this) - - def dateadd_sql(self, expression: exp.DateAdd) -> 
str: - from sqlglot.dialects.dialect import unit_to_str - - return self.func( - "DATE_ADD", expression.this, expression.expression, unit_to_str(expression) - ) - - def arrayany_sql(self, expression: exp.ArrayAny) -> str: - if self.CAN_IMPLEMENT_ARRAY_ANY: - filtered = exp.ArrayFilter(this=expression.this, expression=expression.expression) - filtered_not_empty = exp.ArraySize(this=filtered).neq(0) - original_is_empty = exp.ArraySize(this=expression.this).eq(0) - return self.sql(exp.paren(original_is_empty.or_(filtered_not_empty))) - - from sqlglot.dialects import Dialect - - # SQLGlot's executor supports ARRAY_ANY, so we don't wanna warn for the SQLGlot dialect - if self.dialect.__class__ != Dialect: - self.unsupported("ARRAY_ANY is unsupported") - - return self.function_fallback_sql(expression) - - def struct_sql(self, expression: exp.Struct) -> str: - expression.set( - "expressions", - [ - exp.alias_(e.expression, e.name if e.this.is_string else e.this) - if isinstance(e, exp.PropertyEQ) - else e - for e in expression.expressions - ], - ) - - return self.function_fallback_sql(expression) - - def partitionrange_sql(self, expression: exp.PartitionRange) -> str: - low = self.sql(expression, "this") - high = self.sql(expression, "expression") - - return f"{low} TO {high}" - - def truncatetable_sql(self, expression: exp.TruncateTable) -> str: - target = "DATABASE" if expression.args.get("is_database") else "TABLE" - tables = f" {self.expressions(expression)}" - - exists = " IF EXISTS" if expression.args.get("exists") else "" - - on_cluster = self.sql(expression, "cluster") - on_cluster = f" {on_cluster}" if on_cluster else "" - - identity = self.sql(expression, "identity") - identity = f" {identity} IDENTITY" if identity else "" - - option = self.sql(expression, "option") - option = f" {option}" if option else "" - - partition = self.sql(expression, "partition") - partition = f" {partition}" if partition else "" - - return f"TRUNCATE {target}{exists}{tables}{on_cluster}{identity}{option}{partition}" - - # This transpiles T-SQL's CONVERT function - # https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16 - def convert_sql(self, expression: exp.Convert) -> str: - to = expression.this - value = expression.expression - style = expression.args.get("style") - safe = expression.args.get("safe") - strict = expression.args.get("strict") - - if not to or not value: - return "" - - # Retrieve length of datatype and override to default if not specified - if not seq_get(to.expressions, 0) and to.this in self.PARAMETERIZABLE_TEXT_TYPES: - to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False) - - transformed: t.Optional[exp.Expression] = None - cast = exp.Cast if strict else exp.TryCast - - # Check whether a conversion with format (T-SQL calls this 'style') is applicable - if isinstance(style, exp.Literal) and style.is_int: - from sqlglot.dialects.tsql import TSQL - - style_value = style.name - converted_style = TSQL.CONVERT_FORMAT_MAPPING.get(style_value) - if not converted_style: - self.unsupported(f"Unsupported T-SQL 'style' value: {style_value}") - - fmt = exp.Literal.string(converted_style) - - if to.this == exp.DataType.Type.DATE: - transformed = exp.StrToDate(this=value, format=fmt) - elif to.this in (exp.DataType.Type.DATETIME, exp.DataType.Type.DATETIME2): - transformed = exp.StrToTime(this=value, format=fmt) - elif to.this in self.PARAMETERIZABLE_TEXT_TYPES: - transformed = cast(this=exp.TimeToStr(this=value, 
format=fmt), to=to, safe=safe) - elif to.this == exp.DataType.Type.TEXT: - transformed = exp.TimeToStr(this=value, format=fmt) - - if not transformed: - transformed = cast(this=value, to=to, safe=safe) - - return self.sql(transformed) - - def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str: - this = expression.this - if isinstance(this, exp.JSONPathWildcard): - this = self.json_path_part(this) - return f".{this}" if this else "" - - if exp.SAFE_IDENTIFIER_RE.match(this): - return f".{this}" - - this = self.json_path_part(this) - return ( - f"[{this}]" - if self._quote_json_path_key_using_brackets and self.JSON_PATH_BRACKETED_KEY_SUPPORTED - else f".{this}" - ) - - def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str: - this = self.json_path_part(expression.this) - return f"[{this}]" if this else "" - - def _simplify_unless_literal(self, expression: E) -> E: - if not isinstance(expression, exp.Literal): - from sqlglot.optimizer.simplify import simplify - - expression = simplify(expression, dialect=self.dialect) - - return expression - - def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str: - this = expression.this - if isinstance(this, self.RESPECT_IGNORE_NULLS_UNSUPPORTED_EXPRESSIONS): - self.unsupported( - f"RESPECT/IGNORE NULLS is not supported for {type(this).key} in {self.dialect.__class__.__name__}" - ) - return self.sql(this) - - if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"): - # The first modifier here will be the one closest to the AggFunc's arg - mods = sorted( - expression.find_all(exp.HavingMax, exp.Order, exp.Limit), - key=lambda x: 0 - if isinstance(x, exp.HavingMax) - else (1 if isinstance(x, exp.Order) else 2), - ) - - if mods: - mod = mods[0] - this = expression.__class__(this=mod.this.copy()) - this.meta["inline"] = True - mod.this.replace(this) - return self.sql(expression.this) - - agg_func = expression.find(exp.AggFunc) - - if agg_func: - agg_func_sql = self.sql(agg_func, comment=False)[:-1] + f" {text})" - return self.maybe_comment(agg_func_sql, comments=agg_func.comments) - - return f"{self.sql(expression, 'this')} {text}" - - def _replace_line_breaks(self, string: str) -> str: - """We don't want to extra indent line breaks so we temporarily replace them with sentinels.""" - if self.pretty: - return string.replace("\n", self.SENTINEL_LINE_BREAK) - return string - - def copyparameter_sql(self, expression: exp.CopyParameter) -> str: - option = self.sql(expression, "this") - - if expression.expressions: - upper = option.upper() - - # Snowflake FILE_FORMAT options are separated by whitespace - sep = " " if upper == "FILE_FORMAT" else ", " - - # Databricks copy/format options do not set their list of values with EQ - op = " " if upper in ("COPY_OPTIONS", "FORMAT_OPTIONS") else " = " - values = self.expressions(expression, flat=True, sep=sep) - return f"{option}{op}({values})" - - value = self.sql(expression, "expression") - - if not value: - return option - - op = " = " if self.COPY_PARAMS_EQ_REQUIRED else " " - - return f"{option}{op}{value}" - - def credentials_sql(self, expression: exp.Credentials) -> str: - cred_expr = expression.args.get("credentials") - if isinstance(cred_expr, exp.Literal): - # Redshift case: CREDENTIALS - credentials = self.sql(expression, "credentials") - credentials = f"CREDENTIALS {credentials}" if credentials else "" - else: - # Snowflake case: CREDENTIALS = (...) 
- credentials = self.expressions(expression, key="credentials", flat=True, sep=" ") - credentials = f"CREDENTIALS = ({credentials})" if cred_expr is not None else "" - - storage = self.sql(expression, "storage") - storage = f"STORAGE_INTEGRATION = {storage}" if storage else "" - - encryption = self.expressions(expression, key="encryption", flat=True, sep=" ") - encryption = f" ENCRYPTION = ({encryption})" if encryption else "" - - iam_role = self.sql(expression, "iam_role") - iam_role = f"IAM_ROLE {iam_role}" if iam_role else "" - - region = self.sql(expression, "region") - region = f" REGION {region}" if region else "" - - return f"{credentials}{storage}{encryption}{iam_role}{region}" - - def copy_sql(self, expression: exp.Copy) -> str: - this = self.sql(expression, "this") - this = f" INTO {this}" if self.COPY_HAS_INTO_KEYWORD else f" {this}" - - credentials = self.sql(expression, "credentials") - credentials = self.seg(credentials) if credentials else "" - kind = self.seg("FROM" if expression.args.get("kind") else "TO") - files = self.expressions(expression, key="files", flat=True) - - sep = ", " if self.dialect.COPY_PARAMS_ARE_CSV else " " - params = self.expressions( - expression, - key="params", - sep=sep, - new_line=True, - skip_last=True, - skip_first=True, - indent=self.COPY_PARAMS_ARE_WRAPPED, - ) - - if params: - if self.COPY_PARAMS_ARE_WRAPPED: - params = f" WITH ({params})" - elif not self.pretty: - params = f" {params}" - - return f"COPY{this}{kind} {files}{credentials}{params}" - - def semicolon_sql(self, expression: exp.Semicolon) -> str: - return "" - - def datadeletionproperty_sql(self, expression: exp.DataDeletionProperty) -> str: - on_sql = "ON" if expression.args.get("on") else "OFF" - filter_col: t.Optional[str] = self.sql(expression, "filter_column") - filter_col = f"FILTER_COLUMN={filter_col}" if filter_col else None - retention_period: t.Optional[str] = self.sql(expression, "retention_period") - retention_period = f"RETENTION_PERIOD={retention_period}" if retention_period else None - - if filter_col or retention_period: - on_sql = self.func("ON", filter_col, retention_period) - - return f"DATA_DELETION={on_sql}" - - def maskingpolicycolumnconstraint_sql( - self, expression: exp.MaskingPolicyColumnConstraint - ) -> str: - this = self.sql(expression, "this") - expressions = self.expressions(expression, flat=True) - expressions = f" USING ({expressions})" if expressions else "" - return f"MASKING POLICY {this}{expressions}" - - def gapfill_sql(self, expression: exp.GapFill) -> str: - this = self.sql(expression, "this") - this = f"TABLE {this}" - return self.func("GAP_FILL", this, *[v for k, v in expression.args.items() if k != "this"]) - - def scope_resolution(self, rhs: str, scope_name: str) -> str: - return self.func("SCOPE_RESOLUTION", scope_name or None, rhs) - - def scoperesolution_sql(self, expression: exp.ScopeResolution) -> str: - this = self.sql(expression, "this") - expr = expression.expression - - if isinstance(expr, exp.Func): - # T-SQL's CLR functions are case sensitive - expr = f"{self.sql(expr, 'this')}({self.format_args(*expr.expressions)})" - else: - expr = self.sql(expression, "expression") - - return self.scope_resolution(expr, this) - - def parsejson_sql(self, expression: exp.ParseJSON) -> str: - if self.PARSE_JSON_NAME is None: - return self.sql(expression.this) - - return self.func(self.PARSE_JSON_NAME, expression.this, expression.expression) - - def rand_sql(self, expression: exp.Rand) -> str: - lower = self.sql(expression, "lower") - upper = 
self.sql(expression, "upper") - - if lower and upper: - return f"({upper} - {lower}) * {self.func('RAND', expression.this)} + {lower}" - return self.func("RAND", expression.this) - - def changes_sql(self, expression: exp.Changes) -> str: - information = self.sql(expression, "information") - information = f"INFORMATION => {information}" - at_before = self.sql(expression, "at_before") - at_before = f"{self.seg('')}{at_before}" if at_before else "" - end = self.sql(expression, "end") - end = f"{self.seg('')}{end}" if end else "" - - return f"CHANGES ({information}){at_before}{end}" - - def pad_sql(self, expression: exp.Pad) -> str: - prefix = "L" if expression.args.get("is_left") else "R" - - fill_pattern = self.sql(expression, "fill_pattern") or None - if not fill_pattern and self.PAD_FILL_PATTERN_IS_REQUIRED: - fill_pattern = "' '" - - return self.func(f"{prefix}PAD", expression.this, expression.expression, fill_pattern) - - def summarize_sql(self, expression: exp.Summarize) -> str: - table = " TABLE" if expression.args.get("table") else "" - return f"SUMMARIZE{table} {self.sql(expression.this)}" - - def explodinggenerateseries_sql(self, expression: exp.ExplodingGenerateSeries) -> str: - generate_series = exp.GenerateSeries(**expression.args) - - parent = expression.parent - if isinstance(parent, (exp.Alias, exp.TableAlias)): - parent = parent.parent - - if self.SUPPORTS_EXPLODING_PROJECTIONS and not isinstance(parent, (exp.Table, exp.Unnest)): - return self.sql(exp.Unnest(expressions=[generate_series])) - - if isinstance(parent, exp.Select): - self.unsupported("GenerateSeries projection unnesting is not supported.") - - return self.sql(generate_series) - - def arrayconcat_sql(self, expression: exp.ArrayConcat, name: str = "ARRAY_CONCAT") -> str: - exprs = expression.expressions - if not self.ARRAY_CONCAT_IS_VAR_LEN: - rhs = reduce(lambda x, y: exp.ArrayConcat(this=x, expressions=[y]), exprs) - else: - rhs = self.expressions(expression) - - return self.func(name, expression.this, rhs or None) - - def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str: - if self.SUPPORTS_CONVERT_TIMEZONE: - return self.function_fallback_sql(expression) - - source_tz = expression.args.get("source_tz") - target_tz = expression.args.get("target_tz") - timestamp = expression.args.get("timestamp") - - if source_tz and timestamp: - timestamp = exp.AtTimeZone( - this=exp.cast(timestamp, exp.DataType.Type.TIMESTAMPNTZ), zone=source_tz - ) - - expr = exp.AtTimeZone(this=timestamp, zone=target_tz) - - return self.sql(expr) - - def json_sql(self, expression: exp.JSON) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - - _with = expression.args.get("with") - - if _with is None: - with_sql = "" - elif not _with: - with_sql = " WITHOUT" - else: - with_sql = " WITH" - - unique_sql = " UNIQUE KEYS" if expression.args.get("unique") else "" - - return f"JSON{this}{with_sql}{unique_sql}" - - def jsonvalue_sql(self, expression: exp.JSONValue) -> str: - def _generate_on_options(arg: t.Any) -> str: - return arg if isinstance(arg, str) else f"DEFAULT {self.sql(arg)}" - - path = self.sql(expression, "path") - returning = self.sql(expression, "returning") - returning = f" RETURNING {returning}" if returning else "" - - on_condition = self.sql(expression, "on_condition") - on_condition = f" {on_condition}" if on_condition else "" - - return self.func("JSON_VALUE", expression.this, f"{path}{returning}{on_condition}") - - def conditionalinsert_sql(self, expression: exp.ConditionalInsert) 
-> str: - else_ = "ELSE " if expression.args.get("else_") else "" - condition = self.sql(expression, "expression") - condition = f"WHEN {condition} THEN " if condition else else_ - insert = self.sql(expression, "this")[len("INSERT") :].strip() - return f"{condition}{insert}" - - def multitableinserts_sql(self, expression: exp.MultitableInserts) -> str: - kind = self.sql(expression, "kind") - expressions = self.seg(self.expressions(expression, sep=" ")) - res = f"INSERT {kind}{expressions}{self.seg(self.sql(expression, 'source'))}" - return res - - def oncondition_sql(self, expression: exp.OnCondition) -> str: - # Static options like "NULL ON ERROR" are stored as strings, in contrast to "DEFAULT ON ERROR" - empty = expression.args.get("empty") - empty = ( - f"DEFAULT {empty} ON EMPTY" - if isinstance(empty, exp.Expression) - else self.sql(expression, "empty") - ) - - error = expression.args.get("error") - error = ( - f"DEFAULT {error} ON ERROR" - if isinstance(error, exp.Expression) - else self.sql(expression, "error") - ) - - if error and empty: - error = ( - f"{empty} {error}" - if self.dialect.ON_CONDITION_EMPTY_BEFORE_ERROR - else f"{error} {empty}" - ) - empty = "" - - null = self.sql(expression, "null") - - return f"{empty}{error}{null}" - - def jsonextractquote_sql(self, expression: exp.JSONExtractQuote) -> str: - scalar = " ON SCALAR STRING" if expression.args.get("scalar") else "" - return f"{self.sql(expression, 'option')} QUOTES{scalar}" - - def jsonexists_sql(self, expression: exp.JSONExists) -> str: - this = self.sql(expression, "this") - path = self.sql(expression, "path") - - passing = self.expressions(expression, "passing") - passing = f" PASSING {passing}" if passing else "" - - on_condition = self.sql(expression, "on_condition") - on_condition = f" {on_condition}" if on_condition else "" - - path = f"{path}{passing}{on_condition}" - - return self.func("JSON_EXISTS", this, path) - - def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: - array_agg = self.function_fallback_sql(expression) - - # Add a NULL FILTER on the column to mimic the results going from a dialect that excludes nulls - # on ARRAY_AGG (e.g Spark) to one that doesn't (e.g. DuckDB) - if self.dialect.ARRAY_AGG_INCLUDES_NULLS and expression.args.get("nulls_excluded"): - parent = expression.parent - if isinstance(parent, exp.Filter): - parent_cond = parent.expression.this - parent_cond.replace(parent_cond.and_(expression.this.is_(exp.null()).not_())) - else: - this = expression.this - # Do not add the filter if the input is not a column (e.g. 
literal, struct etc) - if this.find(exp.Column): - # DISTINCT is already present in the agg function, do not propagate it to FILTER as well - this_sql = ( - self.expressions(this) - if isinstance(this, exp.Distinct) - else self.sql(expression, "this") - ) - - array_agg = f"{array_agg} FILTER(WHERE {this_sql} IS NOT NULL)" - - return array_agg - - def apply_sql(self, expression: exp.Apply) -> str: - this = self.sql(expression, "this") - expr = self.sql(expression, "expression") - - return f"{this} APPLY({expr})" - - def grant_sql(self, expression: exp.Grant) -> str: - privileges_sql = self.expressions(expression, key="privileges", flat=True) - - kind = self.sql(expression, "kind") - kind = f" {kind}" if kind else "" - - securable = self.sql(expression, "securable") - securable = f" {securable}" if securable else "" - - principals = self.expressions(expression, key="principals", flat=True) - - grant_option = " WITH GRANT OPTION" if expression.args.get("grant_option") else "" - - return f"GRANT {privileges_sql} ON{kind}{securable} TO {principals}{grant_option}" - - def grantprivilege_sql(self, expression: exp.GrantPrivilege): - this = self.sql(expression, "this") - columns = self.expressions(expression, flat=True) - columns = f"({columns})" if columns else "" - - return f"{this}{columns}" - - def grantprincipal_sql(self, expression: exp.GrantPrincipal): - this = self.sql(expression, "this") - - kind = self.sql(expression, "kind") - kind = f"{kind} " if kind else "" - - return f"{kind}{this}" - - def columns_sql(self, expression: exp.Columns): - func = self.function_fallback_sql(expression) - if expression.args.get("unpack"): - func = f"*{func}" - - return func - - def overlay_sql(self, expression: exp.Overlay): - this = self.sql(expression, "this") - expr = self.sql(expression, "expression") - from_sql = self.sql(expression, "from") - for_sql = self.sql(expression, "for") - for_sql = f" FOR {for_sql}" if for_sql else "" - - return f"OVERLAY({this} PLACING {expr} FROM {from_sql}{for_sql})" - - @unsupported_args("format") - def todouble_sql(self, expression: exp.ToDouble) -> str: - return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE)) - - def string_sql(self, expression: exp.String) -> str: - this = expression.this - zone = expression.args.get("zone") - - if zone: - # This is a BigQuery specific argument for STRING(, ) - # BigQuery stores timestamps internally as UTC, so ConvertTimezone is used with UTC - # set for source_tz to transpile the time conversion before the STRING cast - this = exp.ConvertTimezone( - source_tz=exp.Literal.string("UTC"), target_tz=zone, timestamp=this - ) - - return self.sql(exp.cast(this, exp.DataType.Type.VARCHAR)) - - def median_sql(self, expression: exp.Median): - if not self.SUPPORTS_MEDIAN: - return self.sql( - exp.PercentileCont(this=expression.this, expression=exp.Literal.number(0.5)) - ) - - return self.function_fallback_sql(expression) - - def overflowtruncatebehavior_sql(self, expression: exp.OverflowTruncateBehavior) -> str: - filler = self.sql(expression, "this") - filler = f" {filler}" if filler else "" - with_count = "WITH COUNT" if expression.args.get("with_count") else "WITHOUT COUNT" - return f"TRUNCATE{filler} {with_count}" - - def unixseconds_sql(self, expression: exp.UnixSeconds) -> str: - if self.SUPPORTS_UNIX_SECONDS: - return self.function_fallback_sql(expression) - - start_ts = exp.cast( - exp.Literal.string("1970-01-01 00:00:00+00"), to=exp.DataType.Type.TIMESTAMPTZ - ) - - return self.sql( - 
exp.TimestampDiff(this=expression.this, expression=start_ts, unit=exp.var("SECONDS")) - ) - - def arraysize_sql(self, expression: exp.ArraySize) -> str: - dim = expression.expression - - # For dialects that don't support the dimension arg, we can safely transpile it's default value (1st dimension) - if dim and self.ARRAY_SIZE_DIM_REQUIRED is None: - if not (dim.is_int and dim.name == "1"): - self.unsupported("Cannot transpile dimension argument for ARRAY_LENGTH") - dim = None - - # If dimension is required but not specified, default initialize it - if self.ARRAY_SIZE_DIM_REQUIRED and not dim: - dim = exp.Literal.number(1) - - return self.func(self.ARRAY_SIZE_NAME, expression.this, dim) - - def attach_sql(self, expression: exp.Attach) -> str: - this = self.sql(expression, "this") - exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" - expressions = self.expressions(expression) - expressions = f" ({expressions})" if expressions else "" - - return f"ATTACH{exists_sql} {this}{expressions}" - - def detach_sql(self, expression: exp.Detach) -> str: - this = self.sql(expression, "this") - exists_sql = " IF EXISTS" if expression.args.get("exists") else "" - - return f"DETACH{exists_sql} {this}" - - def attachoption_sql(self, expression: exp.AttachOption) -> str: - this = self.sql(expression, "this") - value = self.sql(expression, "expression") - value = f" {value}" if value else "" - return f"{this}{value}" - - def featuresattime_sql(self, expression: exp.FeaturesAtTime) -> str: - this_sql = self.sql(expression, "this") - if isinstance(expression.this, exp.Table): - this_sql = f"TABLE {this_sql}" - - return self.func( - "FEATURES_AT_TIME", - this_sql, - expression.args.get("time"), - expression.args.get("num_rows"), - expression.args.get("ignore_feature_nulls"), - ) - - def watermarkcolumnconstraint_sql(self, expression: exp.WatermarkColumnConstraint) -> str: - return ( - f"WATERMARK FOR {self.sql(expression, 'this')} AS {self.sql(expression, 'expression')}" - ) - - def encodeproperty_sql(self, expression: exp.EncodeProperty) -> str: - encode = "KEY ENCODE" if expression.args.get("key") else "ENCODE" - encode = f"{encode} {self.sql(expression, 'this')}" - - properties = expression.args.get("properties") - if properties: - encode = f"{encode} {self.properties(properties)}" - - return encode - - def includeproperty_sql(self, expression: exp.IncludeProperty) -> str: - this = self.sql(expression, "this") - include = f"INCLUDE {this}" - - column_def = self.sql(expression, "column_def") - if column_def: - include = f"{include} {column_def}" - - alias = self.sql(expression, "alias") - if alias: - include = f"{include} AS {alias}" - - return include - - def xmlelement_sql(self, expression: exp.XMLElement) -> str: - name = f"NAME {self.sql(expression, 'this')}" - return self.func("XMLELEMENT", name, *expression.expressions) - - def xmlkeyvalueoption_sql(self, expression: exp.XMLKeyValueOption) -> str: - this = self.sql(expression, "this") - expr = self.sql(expression, "expression") - expr = f"({expr})" if expr else "" - return f"{this}{expr}" - - def partitionbyrangeproperty_sql(self, expression: exp.PartitionByRangeProperty) -> str: - partitions = self.expressions(expression, "partition_expressions") - create = self.expressions(expression, "create_expressions") - return f"PARTITION BY RANGE {self.wrap(partitions)} {self.wrap(create)}" - - def partitionbyrangepropertydynamic_sql( - self, expression: exp.PartitionByRangePropertyDynamic - ) -> str: - start = self.sql(expression, "start") - 
end = self.sql(expression, "end") - - every = expression.args["every"] - if isinstance(every, exp.Interval) and every.this.is_string: - every.this.replace(exp.Literal.number(every.name)) - - return f"START {self.wrap(start)} END {self.wrap(end)} EVERY {self.wrap(self.sql(every))}" - - def unpivotcolumns_sql(self, expression: exp.UnpivotColumns) -> str: - name = self.sql(expression, "this") - values = self.expressions(expression, flat=True) - - return f"NAME {name} VALUE {values}" - - def analyzesample_sql(self, expression: exp.AnalyzeSample) -> str: - kind = self.sql(expression, "kind") - sample = self.sql(expression, "sample") - return f"SAMPLE {sample} {kind}" - - def analyzestatistics_sql(self, expression: exp.AnalyzeStatistics) -> str: - kind = self.sql(expression, "kind") - option = self.sql(expression, "option") - option = f" {option}" if option else "" - this = self.sql(expression, "this") - this = f" {this}" if this else "" - columns = self.expressions(expression) - columns = f" {columns}" if columns else "" - return f"{kind}{option} STATISTICS{this}{columns}" - - def analyzehistogram_sql(self, expression: exp.AnalyzeHistogram) -> str: - this = self.sql(expression, "this") - columns = self.expressions(expression) - inner_expression = self.sql(expression, "expression") - inner_expression = f" {inner_expression}" if inner_expression else "" - update_options = self.sql(expression, "update_options") - update_options = f" {update_options} UPDATE" if update_options else "" - return f"{this} HISTOGRAM ON {columns}{inner_expression}{update_options}" - - def analyzedelete_sql(self, expression: exp.AnalyzeDelete) -> str: - kind = self.sql(expression, "kind") - kind = f" {kind}" if kind else "" - return f"DELETE{kind} STATISTICS" - - def analyzelistchainedrows_sql(self, expression: exp.AnalyzeListChainedRows) -> str: - inner_expression = self.sql(expression, "expression") - return f"LIST CHAINED ROWS{inner_expression}" - - def analyzevalidate_sql(self, expression: exp.AnalyzeValidate) -> str: - kind = self.sql(expression, "kind") - this = self.sql(expression, "this") - this = f" {this}" if this else "" - inner_expression = self.sql(expression, "expression") - return f"VALIDATE {kind}{this}{inner_expression}" - - def analyze_sql(self, expression: exp.Analyze) -> str: - options = self.expressions(expression, key="options", sep=" ") - options = f" {options}" if options else "" - kind = self.sql(expression, "kind") - kind = f" {kind}" if kind else "" - this = self.sql(expression, "this") - this = f" {this}" if this else "" - mode = self.sql(expression, "mode") - mode = f" {mode}" if mode else "" - properties = self.sql(expression, "properties") - properties = f" {properties}" if properties else "" - partition = self.sql(expression, "partition") - partition = f" {partition}" if partition else "" - inner_expression = self.sql(expression, "expression") - inner_expression = f" {inner_expression}" if inner_expression else "" - return f"ANALYZE{options}{kind}{this}{partition}{mode}{inner_expression}{properties}" - - def xmltable_sql(self, expression: exp.XMLTable) -> str: - this = self.sql(expression, "this") - namespaces = self.expressions(expression, key="namespaces") - namespaces = f"XMLNAMESPACES({namespaces}), " if namespaces else "" - passing = self.expressions(expression, key="passing") - passing = f"{self.sep()}PASSING{self.seg(passing)}" if passing else "" - columns = self.expressions(expression, key="columns") - columns = f"{self.sep()}COLUMNS{self.seg(columns)}" if columns else "" - by_ref = 
f"{self.sep()}RETURNING SEQUENCE BY REF" if expression.args.get("by_ref") else "" - return f"XMLTABLE({self.sep('')}{self.indent(namespaces + this + passing + by_ref + columns)}{self.seg(')', sep='')}" - - def xmlnamespace_sql(self, expression: exp.XMLNamespace) -> str: - this = self.sql(expression, "this") - return this if isinstance(expression.this, exp.Alias) else f"DEFAULT {this}" - - def export_sql(self, expression: exp.Export) -> str: - this = self.sql(expression, "this") - connection = self.sql(expression, "connection") - connection = f"WITH CONNECTION {connection} " if connection else "" - options = self.sql(expression, "options") - return f"EXPORT DATA {connection}{options} AS {this}" - - def declare_sql(self, expression: exp.Declare) -> str: - return f"DECLARE {self.expressions(expression, flat=True)}" - - def declareitem_sql(self, expression: exp.DeclareItem) -> str: - variable = self.sql(expression, "this") - default = self.sql(expression, "default") - default = f" = {default}" if default else "" - - kind = self.sql(expression, "kind") - if isinstance(expression.args.get("kind"), exp.Schema): - kind = f"TABLE {kind}" - - return f"{variable} AS {kind}{default}" - - def recursivewithsearch_sql(self, expression: exp.RecursiveWithSearch) -> str: - kind = self.sql(expression, "kind") - this = self.sql(expression, "this") - set = self.sql(expression, "expression") - using = self.sql(expression, "using") - using = f" USING {using}" if using else "" - - kind_sql = kind if kind == "CYCLE" else f"SEARCH {kind} FIRST BY" - - return f"{kind_sql} {this} SET {set}{using}" - - def parameterizedagg_sql(self, expression: exp.ParameterizedAgg) -> str: - params = self.expressions(expression, key="params", flat=True) - return self.func(expression.name, *expression.expressions) + f"({params})" - - def anonymousaggfunc_sql(self, expression: exp.AnonymousAggFunc) -> str: - return self.func(expression.name, *expression.expressions) - - def combinedaggfunc_sql(self, expression: exp.CombinedAggFunc) -> str: - return self.anonymousaggfunc_sql(expression) - - def combinedparameterizedagg_sql(self, expression: exp.CombinedParameterizedAgg) -> str: - return self.parameterizedagg_sql(expression) - - def show_sql(self, expression: exp.Show) -> str: - self.unsupported("Unsupported SHOW statement") - return "" - - def get_put_sql(self, expression: exp.Put | exp.Get) -> str: - # Snowflake GET/PUT statements: - # PUT - # GET - props = expression.args.get("properties") - props_sql = self.properties(props, prefix=" ", sep=" ", wrapped=False) if props else "" - this = self.sql(expression, "this") - target = self.sql(expression, "target") - - if isinstance(expression, exp.Put): - return f"PUT {this} {target}{props_sql}" - else: - return f"GET {target} {this}{props_sql}" - - def translatecharacters_sql(self, expression: exp.TranslateCharacters): - this = self.sql(expression, "this") - expr = self.sql(expression, "expression") - with_error = " WITH ERROR" if expression.args.get("with_error") else "" - return f"TRANSLATE({this} USING {expr}{with_error})" diff --git a/altimate_packages/sqlglot/helper.py b/altimate_packages/sqlglot/helper.py deleted file mode 100644 index 63944f954..000000000 --- a/altimate_packages/sqlglot/helper.py +++ /dev/null @@ -1,582 +0,0 @@ -from __future__ import annotations - -import datetime -import inspect -import logging -import re -import sys -import typing as t -from collections.abc import Collection, Set -from contextlib import contextmanager -from copy import copy -from difflib import 
get_close_matches -from enum import Enum -from itertools import count - -if t.TYPE_CHECKING: - from sqlglot import exp - from sqlglot._typing import A, E, T - from sqlglot.dialects.dialect import DialectType - from sqlglot.expressions import Expression - - -CAMEL_CASE_PATTERN = re.compile("(? t.Any: - return classmethod(self.fget).__get__(None, owner)() # type: ignore - - -def suggest_closest_match_and_fail( - kind: str, - word: str, - possibilities: t.Iterable[str], -) -> None: - close_matches = get_close_matches(word, possibilities, n=1) - - similar = seq_get(close_matches, 0) or "" - if similar: - similar = f" Did you mean {similar}?" - - raise ValueError(f"Unknown {kind} '{word}'.{similar}") - - -def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]: - """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds.""" - try: - return seq[index] - except IndexError: - return None - - -@t.overload -def ensure_list(value: t.Collection[T]) -> t.List[T]: ... - - -@t.overload -def ensure_list(value: None) -> t.List: ... - - -@t.overload -def ensure_list(value: T) -> t.List[T]: ... - - -def ensure_list(value): - """ - Ensures that a value is a list, otherwise casts or wraps it into one. - - Args: - value: The value of interest. - - Returns: - The value cast as a list if it's a list or a tuple, or else the value wrapped in a list. - """ - if value is None: - return [] - if isinstance(value, (list, tuple)): - return list(value) - - return [value] - - -@t.overload -def ensure_collection(value: t.Collection[T]) -> t.Collection[T]: ... - - -@t.overload -def ensure_collection(value: T) -> t.Collection[T]: ... - - -def ensure_collection(value): - """ - Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list. - - Args: - value: The value of interest. - - Returns: - The value if it's a collection, or else the value wrapped in a list. - """ - if value is None: - return [] - return ( - value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value] - ) - - -def csv(*args: str, sep: str = ", ") -> str: - """ - Formats any number of string arguments as CSV. - - Args: - args: The string arguments to format. - sep: The argument separator. - - Returns: - The arguments formatted as a CSV string. - """ - return sep.join(arg for arg in args if arg) - - -def subclasses( - module_name: str, - classes: t.Type | t.Tuple[t.Type, ...], - exclude: t.Type | t.Tuple[t.Type, ...] = (), -) -> t.List[t.Type]: - """ - Returns all subclasses for a collection of classes, possibly excluding some of them. - - Args: - module_name: The name of the module to search for subclasses in. - classes: Class(es) we want to find the subclasses of. - exclude: Class(es) we want to exclude from the returned list. - - Returns: - The target subclasses. - """ - return [ - obj - for _, obj in inspect.getmembers( - sys.modules[module_name], - lambda obj: inspect.isclass(obj) and issubclass(obj, classes) and obj not in exclude, - ) - ] - - -def apply_index_offset( - this: exp.Expression, - expressions: t.List[E], - offset: int, - dialect: DialectType = None, -) -> t.List[E]: - """ - Applies an offset to a given integer literal expression. - - Args: - this: The target of the index. - expressions: The expression the offset will be applied to, wrapped in a list. - offset: The offset that will be applied. - dialect: the dialect of interest. - - Returns: - The original expression with the offset applied to it, wrapped in a list. 
If the provided - `expressions` argument contains more than one expression, it's returned unaffected. - """ - if not offset or len(expressions) != 1: - return expressions - - expression = expressions[0] - - from sqlglot import exp - from sqlglot.optimizer.annotate_types import annotate_types - from sqlglot.optimizer.simplify import simplify - - if not this.type: - annotate_types(this, dialect=dialect) - - if t.cast(exp.DataType, this.type).this not in ( - exp.DataType.Type.UNKNOWN, - exp.DataType.Type.ARRAY, - ): - return expressions - - if not expression.type: - annotate_types(expression, dialect=dialect) - - if t.cast(exp.DataType, expression.type).this in exp.DataType.INTEGER_TYPES: - logger.info("Applying array index offset (%s)", offset) - expression = simplify(expression + offset) - return [expression] - - return expressions - - -def camel_to_snake_case(name: str) -> str: - """Converts `name` from camelCase to snake_case and returns the result.""" - return CAMEL_CASE_PATTERN.sub("_", name).upper() - - -def while_changing(expression: Expression, func: t.Callable[[Expression], E]) -> E: - """ - Applies a transformation to a given expression until a fix point is reached. - - Args: - expression: The expression to be transformed. - func: The transformation to be applied. - - Returns: - The transformed expression. - """ - end_hash: t.Optional[int] = None - - while True: - # No need to walk the ASTโ€“ we've already cached the hashes in the previous iteration - if end_hash is None: - for n in reversed(tuple(expression.walk())): - n._hash = hash(n) - - start_hash = hash(expression) - expression = func(expression) - - expression_nodes = tuple(expression.walk()) - - # Uncache previous caches so we can recompute them - for n in reversed(expression_nodes): - n._hash = None - n._hash = hash(n) - - end_hash = hash(expression) - - if start_hash == end_hash: - # ... and reset the hash so we don't risk it becoming out of date if a mutation happens - for n in expression_nodes: - n._hash = None - - break - - return expression - - -def tsort(dag: t.Dict[T, t.Set[T]]) -> t.List[T]: - """ - Sorts a given directed acyclic graph in topological order. - - Args: - dag: The graph to be sorted. - - Returns: - A list that contains all of the graph's nodes in topological order. - """ - result = [] - - for node, deps in tuple(dag.items()): - for dep in deps: - if dep not in dag: - dag[dep] = set() - - while dag: - current = {node for node, deps in dag.items() if not deps} - - if not current: - raise ValueError("Cycle error") - - for node in current: - dag.pop(node) - - for deps in dag.values(): - deps -= current - - result.extend(sorted(current)) # type: ignore - - return result - - -def open_file(file_name: str) -> t.TextIO: - """Open a file that may be compressed as gzip and return it in universal newline mode.""" - with open(file_name, "rb") as f: - gzipped = f.read(2) == b"\x1f\x8b" - - if gzipped: - import gzip - - return gzip.open(file_name, "rt", newline="") - - return open(file_name, encoding="utf-8", newline="") - - -@contextmanager -def csv_reader(read_csv: exp.ReadCSV) -> t.Any: - """ - Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`. - - Args: - read_csv: A `ReadCSV` function call. - - Yields: - A python csv reader. 
- """ - args = read_csv.expressions - file = open_file(read_csv.name) - - delimiter = "," - args = iter(arg.name for arg in args) # type: ignore - for k, v in zip(args, args): - if k == "delimiter": - delimiter = v - - try: - import csv as csv_ - - yield csv_.reader(file, delimiter=delimiter) - finally: - file.close() - - -def find_new_name(taken: t.Collection[str], base: str) -> str: - """ - Searches for a new name. - - Args: - taken: A collection of taken names. - base: Base name to alter. - - Returns: - The new, available name. - """ - if base not in taken: - return base - - i = 2 - new = f"{base}_{i}" - while new in taken: - i += 1 - new = f"{base}_{i}" - - return new - - -def is_int(text: str) -> bool: - return is_type(text, int) - - -def is_float(text: str) -> bool: - return is_type(text, float) - - -def is_type(text: str, target_type: t.Type) -> bool: - try: - target_type(text) - return True - except ValueError: - return False - - -def name_sequence(prefix: str) -> t.Callable[[], str]: - """Returns a name generator given a prefix (e.g. a0, a1, a2, ... if the prefix is "a").""" - sequence = count() - return lambda: f"{prefix}{next(sequence)}" - - -def object_to_dict(obj: t.Any, **kwargs) -> t.Dict: - """Returns a dictionary created from an object's attributes.""" - return { - **{k: v.copy() if hasattr(v, "copy") else copy(v) for k, v in vars(obj).items()}, - **kwargs, - } - - -def split_num_words( - value: str, sep: str, min_num_words: int, fill_from_start: bool = True -) -> t.List[t.Optional[str]]: - """ - Perform a split on a value and return N words as a result with `None` used for words that don't exist. - - Args: - value: The value to be split. - sep: The value to use to split on. - min_num_words: The minimum number of words that are going to be in the result. - fill_from_start: Indicates that if `None` values should be inserted at the start or end of the list. - - Examples: - >>> split_num_words("db.table", ".", 3) - [None, 'db', 'table'] - >>> split_num_words("db.table", ".", 3, fill_from_start=False) - ['db', 'table', None] - >>> split_num_words("db.table", ".", 1) - ['db', 'table'] - - Returns: - The list of words returned by `split`, possibly augmented by a number of `None` values. - """ - words = value.split(sep) - if fill_from_start: - return [None] * (min_num_words - len(words)) + words - return words + [None] * (min_num_words - len(words)) - - -def is_iterable(value: t.Any) -> bool: - """ - Checks if the value is an iterable, excluding the types `str` and `bytes`. - - Examples: - >>> is_iterable([1,2]) - True - >>> is_iterable("test") - False - - Args: - value: The value to check if it is an iterable. - - Returns: - A `bool` value indicating if it is an iterable. - """ - from sqlglot import Expression - - return hasattr(value, "__iter__") and not isinstance(value, (str, bytes, Expression)) - - -def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Iterator[t.Any]: - """ - Flattens an iterable that can contain both iterable and non-iterable elements. Objects of - type `str` and `bytes` are not regarded as iterables. - - Examples: - >>> list(flatten([[1, 2], 3, {4}, (5, "bla")])) - [1, 2, 3, 4, 5, 'bla'] - >>> list(flatten([1, 2, 3])) - [1, 2, 3] - - Args: - values: The value to be flattened. - - Yields: - Non-iterable elements in `values`. - """ - for value in values: - if is_iterable(value): - yield from flatten(value) - else: - yield value - - -def dict_depth(d: t.Dict) -> int: - """ - Get the nesting depth of a dictionary. 
- - Example: - >>> dict_depth(None) - 0 - >>> dict_depth({}) - 1 - >>> dict_depth({"a": "b"}) - 1 - >>> dict_depth({"a": {}}) - 2 - >>> dict_depth({"a": {"b": {}}}) - 3 - """ - try: - return 1 + dict_depth(next(iter(d.values()))) - except AttributeError: - # d doesn't have attribute "values" - return 0 - except StopIteration: - # d.values() returns an empty sequence - return 1 - - -def first(it: t.Iterable[T]) -> T: - """Returns the first element from an iterable (useful for sets).""" - return next(i for i in it) - - -def to_bool(value: t.Optional[str | bool]) -> t.Optional[str | bool]: - if isinstance(value, bool) or value is None: - return value - - # Coerce the value to boolean if it matches to the truthy/falsy values below - value_lower = value.lower() - if value_lower in ("true", "1"): - return True - if value_lower in ("false", "0"): - return False - - return value - - -def merge_ranges(ranges: t.List[t.Tuple[A, A]]) -> t.List[t.Tuple[A, A]]: - """ - Merges a sequence of ranges, represented as tuples (low, high) whose values - belong to some totally-ordered set. - - Example: - >>> merge_ranges([(1, 3), (2, 6)]) - [(1, 6)] - """ - if not ranges: - return [] - - ranges = sorted(ranges) - - merged = [ranges[0]] - - for start, end in ranges[1:]: - last_start, last_end = merged[-1] - - if start <= last_end: - merged[-1] = (last_start, max(last_end, end)) - else: - merged.append((start, end)) - - return merged - - -def is_iso_date(text: str) -> bool: - try: - datetime.date.fromisoformat(text) - return True - except ValueError: - return False - - -def is_iso_datetime(text: str) -> bool: - try: - datetime.datetime.fromisoformat(text) - return True - except ValueError: - return False - - -# Interval units that operate on date components -DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} - - -def is_date_unit(expression: t.Optional[exp.Expression]) -> bool: - return expression is not None and expression.name.lower() in DATE_UNITS - - -K = t.TypeVar("K") -V = t.TypeVar("V") - - -class SingleValuedMapping(t.Mapping[K, V]): - """ - Mapping where all keys return the same value. - - This rigamarole is meant to avoid copying keys, which was originally intended - as an optimization while qualifying columns for tables with lots of columns. 
- """ - - def __init__(self, keys: t.Collection[K], value: V): - self._keys = keys if isinstance(keys, Set) else set(keys) - self._value = value - - def __getitem__(self, key: K) -> V: - if key in self._keys: - return self._value - raise KeyError(key) - - def __len__(self) -> int: - return len(self._keys) - - def __iter__(self) -> t.Iterator[K]: - return iter(self._keys) diff --git a/altimate_packages/sqlglot/jsonpath.py b/altimate_packages/sqlglot/jsonpath.py deleted file mode 100644 index 115bd1594..000000000 --- a/altimate_packages/sqlglot/jsonpath.py +++ /dev/null @@ -1,227 +0,0 @@ -from __future__ import annotations - -import typing as t - -import sqlglot.expressions as exp -from sqlglot.errors import ParseError -from sqlglot.tokens import Token, Tokenizer, TokenType - -if t.TYPE_CHECKING: - from sqlglot._typing import Lit - from sqlglot.dialects.dialect import DialectType - - -class JSONPathTokenizer(Tokenizer): - SINGLE_TOKENS = { - "(": TokenType.L_PAREN, - ")": TokenType.R_PAREN, - "[": TokenType.L_BRACKET, - "]": TokenType.R_BRACKET, - ":": TokenType.COLON, - ",": TokenType.COMMA, - "-": TokenType.DASH, - ".": TokenType.DOT, - "?": TokenType.PLACEHOLDER, - "@": TokenType.PARAMETER, - "'": TokenType.QUOTE, - '"': TokenType.QUOTE, - "$": TokenType.DOLLAR, - "*": TokenType.STAR, - } - - KEYWORDS = { - "..": TokenType.DOT, - } - - IDENTIFIER_ESCAPES = ["\\"] - STRING_ESCAPES = ["\\"] - - -def parse(path: str, dialect: DialectType = None) -> exp.JSONPath: - """Takes in a JSON path string and parses it into a JSONPath expression.""" - from sqlglot.dialects import Dialect - - jsonpath_tokenizer = Dialect.get_or_raise(dialect).jsonpath_tokenizer - tokens = jsonpath_tokenizer.tokenize(path) - size = len(tokens) - - i = 0 - - def _curr() -> t.Optional[TokenType]: - return tokens[i].token_type if i < size else None - - def _prev() -> Token: - return tokens[i - 1] - - def _advance() -> Token: - nonlocal i - i += 1 - return _prev() - - def _error(msg: str) -> str: - return f"{msg} at index {i}: {path}" - - @t.overload - def _match(token_type: TokenType, raise_unmatched: Lit[True] = True) -> Token: - pass - - @t.overload - def _match(token_type: TokenType, raise_unmatched: Lit[False] = False) -> t.Optional[Token]: - pass - - def _match(token_type, raise_unmatched=False): - if _curr() == token_type: - return _advance() - if raise_unmatched: - raise ParseError(_error(f"Expected {token_type}")) - return None - - def _parse_literal() -> t.Any: - token = _match(TokenType.STRING) or _match(TokenType.IDENTIFIER) - if token: - return token.text - if _match(TokenType.STAR): - return exp.JSONPathWildcard() - if _match(TokenType.PLACEHOLDER) or _match(TokenType.L_PAREN): - script = _prev().text == "(" - start = i - - while True: - if _match(TokenType.L_BRACKET): - _parse_bracket() # nested call which we can throw away - if _curr() in (TokenType.R_BRACKET, None): - break - _advance() - - expr_type = exp.JSONPathScript if script else exp.JSONPathFilter - return expr_type(this=path[tokens[start].start : tokens[i].end]) - - number = "-" if _match(TokenType.DASH) else "" - - token = _match(TokenType.NUMBER) - if token: - number += token.text - - if number: - return int(number) - - return False - - def _parse_slice() -> t.Any: - start = _parse_literal() - end = _parse_literal() if _match(TokenType.COLON) else None - step = _parse_literal() if _match(TokenType.COLON) else None - - if end is None and step is None: - return start - - return exp.JSONPathSlice(start=start, end=end, step=step) - - def 
_parse_bracket() -> exp.JSONPathPart: - literal = _parse_slice() - - if isinstance(literal, str) or literal is not False: - indexes = [literal] - while _match(TokenType.COMMA): - literal = _parse_slice() - - if literal: - indexes.append(literal) - - if len(indexes) == 1: - if isinstance(literal, str): - node: exp.JSONPathPart = exp.JSONPathKey(this=indexes[0]) - elif isinstance(literal, exp.JSONPathPart) and isinstance( - literal, (exp.JSONPathScript, exp.JSONPathFilter) - ): - node = exp.JSONPathSelector(this=indexes[0]) - else: - node = exp.JSONPathSubscript(this=indexes[0]) - else: - node = exp.JSONPathUnion(expressions=indexes) - else: - raise ParseError(_error("Cannot have empty segment")) - - _match(TokenType.R_BRACKET, raise_unmatched=True) - - return node - - def _parse_var_text() -> str: - """ - Consumes & returns the text for a var. In BigQuery it's valid to have a key with spaces - in it, e.g JSON_QUERY(..., '$. a b c ') should produce a single JSONPathKey(' a b c '). - This is done by merging "consecutive" vars until a key separator is found (dot, colon etc) - or the path string is exhausted. - """ - prev_index = i - 2 - - while _match(TokenType.VAR): - pass - - start = 0 if prev_index < 0 else tokens[prev_index].end + 1 - - if i >= len(tokens): - # This key is the last token for the path, so it's text is the remaining path - text = path[start:] - else: - text = path[start : tokens[i].start] - - return text - - # We canonicalize the JSON path AST so that it always starts with a - # "root" element, so paths like "field" will be generated as "$.field" - _match(TokenType.DOLLAR) - expressions: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] - - while _curr(): - if _match(TokenType.DOT) or _match(TokenType.COLON): - recursive = _prev().text == ".." 
- - if _match(TokenType.VAR): - value: t.Optional[str | exp.JSONPathWildcard] = _parse_var_text() - elif _match(TokenType.IDENTIFIER): - value = _prev().text - elif _match(TokenType.STAR): - value = exp.JSONPathWildcard() - else: - value = None - - if recursive: - expressions.append(exp.JSONPathRecursive(this=value)) - elif value: - expressions.append(exp.JSONPathKey(this=value)) - else: - raise ParseError(_error("Expected key name or * after DOT")) - elif _match(TokenType.L_BRACKET): - expressions.append(_parse_bracket()) - elif _match(TokenType.VAR): - expressions.append(exp.JSONPathKey(this=_parse_var_text())) - elif _match(TokenType.IDENTIFIER): - expressions.append(exp.JSONPathKey(this=_prev().text)) - elif _match(TokenType.STAR): - expressions.append(exp.JSONPathWildcard()) - else: - raise ParseError(_error(f"Unexpected {tokens[i].token_type}")) - - return exp.JSONPath(expressions=expressions) - - -JSON_PATH_PART_TRANSFORMS: t.Dict[t.Type[exp.Expression], t.Callable[..., str]] = { - exp.JSONPathFilter: lambda _, e: f"?{e.this}", - exp.JSONPathKey: lambda self, e: self._jsonpathkey_sql(e), - exp.JSONPathRecursive: lambda _, e: f"..{e.this or ''}", - exp.JSONPathRoot: lambda *_: "$", - exp.JSONPathScript: lambda _, e: f"({e.this}", - exp.JSONPathSelector: lambda self, e: f"[{self.json_path_part(e.this)}]", - exp.JSONPathSlice: lambda self, e: ":".join( - "" if p is False else self.json_path_part(p) - for p in [e.args.get("start"), e.args.get("end"), e.args.get("step")] - if p is not None - ), - exp.JSONPathSubscript: lambda self, e: self._jsonpathsubscript_sql(e), - exp.JSONPathUnion: lambda self, - e: f"[{','.join(self.json_path_part(p) for p in e.expressions)}]", - exp.JSONPathWildcard: lambda *_: "*", -} - -ALL_JSON_PATH_PARTS = set(JSON_PATH_PART_TRANSFORMS) diff --git a/altimate_packages/sqlglot/lineage.py b/altimate_packages/sqlglot/lineage.py deleted file mode 100644 index fd5e75ad3..000000000 --- a/altimate_packages/sqlglot/lineage.py +++ /dev/null @@ -1,423 +0,0 @@ -from __future__ import annotations - -import json -import logging -import typing as t -from dataclasses import dataclass, field - -from sqlglot import Schema, exp, maybe_parse -from sqlglot.errors import SqlglotError -from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify -from sqlglot.optimizer.scope import ScopeType - -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - -logger = logging.getLogger("sqlglot") - - -@dataclass(frozen=True) -class Node: - name: str - expression: exp.Expression - source: exp.Expression - downstream: t.List[Node] = field(default_factory=list) - source_name: str = "" - reference_node_name: str = "" - - def walk(self) -> t.Iterator[Node]: - yield self - - for d in self.downstream: - yield from d.walk() - - def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: - nodes = {} - edges = [] - - for node in self.walk(): - if isinstance(node.expression, exp.Table): - label = f"FROM {node.expression.this}" - title = f"
<pre>SELECT {node.name} FROM {node.expression.this}</pre>
" - group = 1 - else: - label = node.expression.sql(pretty=True, dialect=dialect) - source = node.source.transform( - lambda n: ( - exp.Tag(this=n, prefix="", postfix="") if n is node.expression else n - ), - copy=False, - ).sql(pretty=True, dialect=dialect) - title = f"
<pre>{source}</pre>
" - group = 0 - - node_id = id(node) - - nodes[node_id] = { - "id": node_id, - "label": label, - "title": title, - "group": group, - } - - for d in node.downstream: - edges.append({"from": node_id, "to": id(d)}) - return GraphHTML(nodes, edges, **opts) - - -def lineage( - column: str | exp.Column, - sql: str | exp.Expression, - schema: t.Optional[t.Dict | Schema] = None, - sources: t.Optional[t.Mapping[str, str | exp.Query]] = None, - dialect: DialectType = None, - scope: t.Optional[Scope] = None, - trim_selects: bool = True, - **kwargs, -) -> Node: - """Build the lineage graph for a column of a SQL query. - - Args: - column: The column to build the lineage for. - sql: The SQL string or expression. - schema: The schema of tables. - sources: A mapping of queries which will be used to continue building lineage. - dialect: The dialect of input SQL. - scope: A pre-created scope to use instead. - trim_selects: Whether or not to clean up selects by trimming to only relevant columns. - **kwargs: Qualification optimizer kwargs. - - Returns: - A lineage node. - """ - - expression = maybe_parse(sql, dialect=dialect) - column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name - - if sources: - expression = exp.expand( - expression, - {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()}, - dialect=dialect, - ) - - if not scope: - expression = qualify.qualify( - expression, - dialect=dialect, - schema=schema, - **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore - ) - - scope = build_scope(expression) - - if not scope: - raise SqlglotError("Cannot build lineage, sql must be SELECT") - - if not any(select.alias_or_name == column for select in scope.expression.selects): - raise SqlglotError(f"Cannot find column '{column}' in query.") - - return to_node(column, scope, dialect, trim_selects=trim_selects) - - -def to_node( - column: str | int, - scope: Scope, - dialect: DialectType, - scope_name: t.Optional[str] = None, - upstream: t.Optional[Node] = None, - source_name: t.Optional[str] = None, - reference_node_name: t.Optional[str] = None, - trim_selects: bool = True, -) -> Node: - # Find the specific select clause that is the source of the column we want. - # This can either be a specific, named select or a generic `*` clause. 
- select = ( - scope.expression.selects[column] - if isinstance(column, int) - else next( - (select for select in scope.expression.selects if select.alias_or_name == column), - exp.Star() if scope.expression.is_star else scope.expression, - ) - ) - - if isinstance(scope.expression, exp.Subquery): - for source in scope.subquery_scopes: - return to_node( - column, - scope=source, - dialect=dialect, - upstream=upstream, - source_name=source_name, - reference_node_name=reference_node_name, - trim_selects=trim_selects, - ) - if isinstance(scope.expression, exp.SetOperation): - name = type(scope.expression).__name__.upper() - upstream = upstream or Node(name=name, source=scope.expression, expression=select) - - index = ( - column - if isinstance(column, int) - else next( - ( - i - for i, select in enumerate(scope.expression.selects) - if select.alias_or_name == column or select.is_star - ), - -1, # mypy will not allow a None here, but a negative index should never be returned - ) - ) - - if index == -1: - raise ValueError(f"Could not find {column} in {scope.expression}") - - for s in scope.union_scopes: - to_node( - index, - scope=s, - dialect=dialect, - upstream=upstream, - source_name=source_name, - reference_node_name=reference_node_name, - trim_selects=trim_selects, - ) - - return upstream - - if trim_selects and isinstance(scope.expression, exp.Select): - # For better ergonomics in our node labels, replace the full select with - # a version that has only the column we care about. - # "x", SELECT x, y FROM foo - # => "x", SELECT x FROM foo - source = t.cast(exp.Expression, scope.expression.select(select, append=False)) - else: - source = scope.expression - - # Create the node for this step in the lineage chain, and attach it to the previous one. - node = Node( - name=f"{scope_name}.{column}" if scope_name else str(column), - source=source, - expression=select, - source_name=source_name or "", - reference_node_name=reference_node_name or "", - ) - - if upstream: - upstream.downstream.append(node) - - subquery_scopes = { - id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes - } - - for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES): - subquery_scope = subquery_scopes.get(id(subquery)) - if not subquery_scope: - logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}") - continue - - for name in subquery.named_selects: - to_node( - name, - scope=subquery_scope, - dialect=dialect, - upstream=node, - trim_selects=trim_selects, - ) - - # if the select is a star add all scope sources as downstreams - if select.is_star: - for source in scope.sources.values(): - if isinstance(source, Scope): - source = source.expression - node.downstream.append( - Node(name=select.sql(comments=False), source=source, expression=source) - ) - - # Find all columns that went into creating this one to list their lineage nodes. 
- source_columns = set(find_all_in_scope(select, exp.Column)) - - # If the source is a UDTF find columns used in the UDTF to generate the table - if isinstance(source, exp.UDTF): - source_columns |= set(source.find_all(exp.Column)) - derived_tables = [ - source.expression.parent - for source in scope.sources.values() - if isinstance(source, Scope) and source.is_derived_table - ] - else: - derived_tables = scope.derived_tables - - source_names = { - dt.alias: dt.comments[0].split()[1] - for dt in derived_tables - if dt.comments and dt.comments[0].startswith("source: ") - } - - pivots = scope.pivots - pivot = pivots[0] if len(pivots) == 1 and not pivots[0].unpivot else None - if pivot: - # For each aggregation function, the pivot creates a new column for each field in category - # combined with the aggfunc. So the columns parsed have this order: cat_a_value_sum, cat_a, - # b_value_sum, b. Because of this step wise manner the aggfunc 'sum(value) as value_sum' - # belongs to the column indices 0, 2, and the aggfunc 'max(price)' without an alias belongs - # to the column indices 1, 3. Here, only the columns used in the aggregations are of interest - # in the lineage, so lookup the pivot column name by index and map that with the columns used - # in the aggregation. - # - # Example: PIVOT (SUM(value) AS value_sum, MAX(price)) FOR category IN ('a' AS cat_a, 'b') - pivot_columns = pivot.args["columns"] - pivot_aggs_count = len(pivot.expressions) - - pivot_column_mapping = {} - for i, agg in enumerate(pivot.expressions): - agg_cols = list(agg.find_all(exp.Column)) - for col_index in range(i, len(pivot_columns), pivot_aggs_count): - pivot_column_mapping[pivot_columns[col_index].name] = agg_cols - - for c in source_columns: - table = c.table - source = scope.sources.get(table) - - if isinstance(source, Scope): - reference_node_name = None - if source.scope_type == ScopeType.DERIVED_TABLE and table not in source_names: - reference_node_name = table - elif source.scope_type == ScopeType.CTE: - selected_node, _ = scope.selected_sources.get(table, (None, None)) - reference_node_name = selected_node.name if selected_node else None - - # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. - to_node( - c.name, - scope=source, - dialect=dialect, - scope_name=table, - upstream=node, - source_name=source_names.get(table) or source_name, - reference_node_name=reference_node_name, - trim_selects=trim_selects, - ) - elif pivot and pivot.alias_or_name == c.table: - downstream_columns = [] - - column_name = c.name - if any(column_name == pivot_column.name for pivot_column in pivot_columns): - downstream_columns.extend(pivot_column_mapping[column_name]) - else: - # The column is not in the pivot, so it must be an implicit column of the - # pivoted source -- adapt column to be from the implicit pivoted source. 
- downstream_columns.append(exp.column(c.this, table=pivot.parent.this)) - - for downstream_column in downstream_columns: - table = downstream_column.table - source = scope.sources.get(table) - if isinstance(source, Scope): - to_node( - downstream_column.name, - scope=source, - scope_name=table, - dialect=dialect, - upstream=node, - source_name=source_names.get(table) or source_name, - reference_node_name=reference_node_name, - trim_selects=trim_selects, - ) - else: - source = source or exp.Placeholder() - node.downstream.append( - Node( - name=downstream_column.sql(comments=False), - source=source, - expression=source, - ) - ) - else: - # The source is not a scope and the column is not in any pivot - we've reached the end - # of the line. At this point, if a source is not found it means this column's lineage - # is unknown. This can happen if the definition of a source used in a query is not - # passed into the `sources` map. - source = source or exp.Placeholder() - node.downstream.append( - Node(name=c.sql(comments=False), source=source, expression=source) - ) - - return node - - -class GraphHTML: - """Node to HTML generator using vis.js. - - https://visjs.github.io/vis-network/docs/network/ - """ - - def __init__( - self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None - ): - self.imports = imports - - self.options = { - "height": "500px", - "width": "100%", - "layout": { - "hierarchical": { - "enabled": True, - "nodeSpacing": 200, - "sortMethod": "directed", - }, - }, - "interaction": { - "dragNodes": False, - "selectable": False, - }, - "physics": { - "enabled": False, - }, - "edges": { - "arrows": "to", - }, - "nodes": { - "font": "20px monaco", - "shape": "box", - "widthConstraint": { - "maximum": 300, - }, - }, - **(options or {}), - } - - self.nodes = nodes - self.edges = edges - - def __str__(self): - nodes = json.dumps(list(self.nodes.values())) - edges = json.dumps(self.edges) - options = json.dumps(self.options) - imports = ( - """ - - """ - if self.imports - else "" - ) - - return f"""
-
- {imports} - -
""" - - def _repr_html_(self) -> str: - return self.__str__() diff --git a/altimate_packages/sqlglot/optimizer/__init__.py b/altimate_packages/sqlglot/optimizer/__init__.py deleted file mode 100644 index 050f246c9..000000000 --- a/altimate_packages/sqlglot/optimizer/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# ruff: noqa: F401 - -from sqlglot.optimizer.optimizer import RULES as RULES, optimize as optimize -from sqlglot.optimizer.scope import ( - Scope as Scope, - build_scope as build_scope, - find_all_in_scope as find_all_in_scope, - find_in_scope as find_in_scope, - traverse_scope as traverse_scope, - walk_in_scope as walk_in_scope, -) diff --git a/altimate_packages/sqlglot/optimizer/annotate_types.py b/altimate_packages/sqlglot/optimizer/annotate_types.py deleted file mode 100644 index d460b1cc0..000000000 --- a/altimate_packages/sqlglot/optimizer/annotate_types.py +++ /dev/null @@ -1,589 +0,0 @@ -from __future__ import annotations - -import functools -import typing as t - -from sqlglot import exp -from sqlglot.helper import ( - ensure_list, - is_date_unit, - is_iso_date, - is_iso_datetime, - seq_get, -) -from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import Schema, ensure_schema -from sqlglot.dialects.dialect import Dialect - -if t.TYPE_CHECKING: - from sqlglot._typing import B, E - - BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] - BinaryCoercions = t.Dict[ - t.Tuple[exp.DataType.Type, exp.DataType.Type], - BinaryCoercionFunc, - ] - - from sqlglot.dialects.dialect import DialectType, AnnotatorsType - - -def annotate_types( - expression: E, - schema: t.Optional[t.Dict | Schema] = None, - annotators: t.Optional[AnnotatorsType] = None, - coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, - dialect: DialectType = None, -) -> E: - """ - Infers the types of an expression, annotating its AST accordingly. - - Example: - >>> import sqlglot - >>> schema = {"y": {"cola": "SMALLINT"}} - >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" - >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) - >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" - - - Args: - expression: Expression to annotate. - schema: Database schema. - annotators: Maps expression type to corresponding annotation function. - coerces_to: Maps expression type to set of types that it can be coerced into. - - Returns: - The expression annotated with types. 
- """ - - schema = ensure_schema(schema, dialect=dialect) - - return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) - - -def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: - date_text = l.name - is_iso_date_ = is_iso_date(date_text) - - if is_iso_date_ and is_date_unit(unit): - return exp.DataType.Type.DATE - - # An ISO date is also an ISO datetime, but not vice versa - if is_iso_date_ or is_iso_datetime(date_text): - return exp.DataType.Type.DATETIME - - return exp.DataType.Type.UNKNOWN - - -def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: - if not is_date_unit(unit): - return exp.DataType.Type.DATETIME - return l.type.this if l.type else exp.DataType.Type.UNKNOWN - - -def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc: - @functools.wraps(func) - def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: - return func(r, l) - - return _swapped - - -def swap_all(coercions: BinaryCoercions) -> BinaryCoercions: - return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}} - - -class _TypeAnnotator(type): - def __new__(cls, clsname, bases, attrs): - klass = super().__new__(cls, clsname, bases, attrs) - - # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI): - # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html - text_precedence = ( - exp.DataType.Type.TEXT, - exp.DataType.Type.NVARCHAR, - exp.DataType.Type.VARCHAR, - exp.DataType.Type.NCHAR, - exp.DataType.Type.CHAR, - ) - numeric_precedence = ( - exp.DataType.Type.DOUBLE, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.BIGINT, - exp.DataType.Type.INT, - exp.DataType.Type.SMALLINT, - exp.DataType.Type.TINYINT, - ) - timelike_precedence = ( - exp.DataType.Type.TIMESTAMPLTZ, - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMP, - exp.DataType.Type.DATETIME, - exp.DataType.Type.DATE, - ) - - for type_precedence in (text_precedence, numeric_precedence, timelike_precedence): - coerces_to = set() - for data_type in type_precedence: - klass.COERCES_TO[data_type] = coerces_to.copy() - coerces_to |= {data_type} - - # NULL can be coerced to any type, so e.g. NULL + 1 will have type INT - klass.COERCES_TO[exp.DataType.Type.NULL] = { - *text_precedence, - *numeric_precedence, - *timelike_precedence, - } - - return klass - - -class TypeAnnotator(metaclass=_TypeAnnotator): - NESTED_TYPES = { - exp.DataType.Type.ARRAY, - } - - # Specifies what types a given type can be coerced into (autofilled) - COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} - - # Coercion functions for binary operations. - # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type. 
- BINARY_COERCIONS: BinaryCoercions = { - **swap_all( - { - (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal( - l, r.args.get("unit") - ) - for t in exp.DataType.TEXT_TYPES - } - ), - **swap_all( - { - # text + numeric will yield the numeric type to match most dialects' semantics - (text, numeric): lambda l, r: t.cast( - exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type - ) - for text in exp.DataType.TEXT_TYPES - for numeric in exp.DataType.NUMERIC_TYPES - } - ), - **swap_all( - { - (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date( - l, r.args.get("unit") - ), - } - ), - } - - def __init__( - self, - schema: Schema, - annotators: t.Optional[AnnotatorsType] = None, - coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, - binary_coercions: t.Optional[BinaryCoercions] = None, - ) -> None: - self.schema = schema - self.annotators = annotators or Dialect.get_or_raise(schema.dialect).ANNOTATORS - self.coerces_to = ( - coerces_to or Dialect.get_or_raise(schema.dialect).COERCES_TO or self.COERCES_TO - ) - self.binary_coercions = binary_coercions or self.BINARY_COERCIONS - - # Caches the ids of annotated sub-Expressions, to ensure we only visit them once - self._visited: t.Set[int] = set() - - # Maps an exp.SetOperation's id (e.g. UNION) to its projection types. This is computed if the - # exp.SetOperation is the expression of a scope source, as selecting from it multiple times - # would reprocess the entire subtree to coerce the types of its operands' projections - self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType | exp.DataType.Type]] = {} - - def _set_type( - self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type] - ) -> None: - expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore - self._visited.add(id(expression)) - - def annotate(self, expression: E) -> E: - for scope in traverse_scope(expression): - self.annotate_scope(scope) - return self._maybe_annotate(expression) # This takes care of non-traversable expressions - - def annotate_scope(self, scope: Scope) -> None: - selects = {} - for name, source in scope.sources.items(): - if not isinstance(source, Scope): - continue - - expression = source.expression - if isinstance(expression, exp.UDTF): - values = [] - - if isinstance(expression, exp.Lateral): - if isinstance(expression.this, exp.Explode): - values = [expression.this.this] - elif isinstance(expression, exp.Unnest): - values = [expression] - elif not isinstance(expression, exp.TableFromRows): - values = expression.expressions[0].expressions - - if not values: - continue - - selects[name] = { - alias: column.type - for alias, column in zip(expression.alias_column_names, values) - } - elif isinstance(expression, exp.SetOperation) and len(expression.left.selects) == len( - expression.right.selects - ): - selects[name] = col_types = self._setop_column_types.setdefault(id(expression), {}) - - if not col_types: - # Process a chain / sub-tree of set operations - for set_op in expression.walk( - prune=lambda n: not isinstance(n, (exp.SetOperation, exp.Subquery)) - ): - if not isinstance(set_op, exp.SetOperation): - continue - - if set_op.args.get("by_name"): - r_type_by_select = { - s.alias_or_name: s.type for s in set_op.right.selects - } - setop_cols = { - s.alias_or_name: self._maybe_coerce( - t.cast(exp.DataType, s.type), - r_type_by_select.get(s.alias_or_name) - or exp.DataType.Type.UNKNOWN, - ) - for s in 
set_op.left.selects - } - else: - setop_cols = { - ls.alias_or_name: self._maybe_coerce( - t.cast(exp.DataType, ls.type), t.cast(exp.DataType, rs.type) - ) - for ls, rs in zip(set_op.left.selects, set_op.right.selects) - } - - # Coerce intermediate results with the previously registered types, if they exist - for col_name, col_type in setop_cols.items(): - col_types[col_name] = self._maybe_coerce( - col_type, col_types.get(col_name, exp.DataType.Type.NULL) - ) - - else: - selects[name] = {s.alias_or_name: s.type for s in expression.selects} - - # First annotate the current scope's column references - for col in scope.columns: - if not col.table: - continue - - source = scope.sources.get(col.table) - if isinstance(source, exp.Table): - self._set_type(col, self.schema.get_column_type(source, col)) - elif source: - if col.table in selects and col.name in selects[col.table]: - self._set_type(col, selects[col.table][col.name]) - elif isinstance(source.expression, exp.Unnest): - self._set_type(col, source.expression.type) - - # Then (possibly) annotate the remaining expressions in the scope - self._maybe_annotate(scope.expression) - - def _maybe_annotate(self, expression: E) -> E: - if id(expression) in self._visited: - return expression # We've already inferred the expression's type - - annotator = self.annotators.get(expression.__class__) - - return ( - annotator(self, expression) - if annotator - else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) - ) - - def _annotate_args(self, expression: E) -> E: - for value in expression.iter_expressions(): - self._maybe_annotate(value) - - return expression - - def _maybe_coerce( - self, - type1: exp.DataType | exp.DataType.Type, - type2: exp.DataType | exp.DataType.Type, - ) -> exp.DataType | exp.DataType.Type: - """ - Returns type2 if type1 can be coerced into it, otherwise type1. - - If either type is parameterized (e.g. DECIMAL(18, 2) contains two parameters), - we assume type1 does not coerce into type2, so we also return it in this case. 
- """ - if isinstance(type1, exp.DataType): - if type1.expressions: - return type1 - type1_value = type1.this - else: - type1_value = type1 - - if isinstance(type2, exp.DataType): - if type2.expressions: - return type2 - type2_value = type2.this - else: - type2_value = type2 - - # We propagate the UNKNOWN type upwards if found - if exp.DataType.Type.UNKNOWN in (type1_value, type2_value): - return exp.DataType.Type.UNKNOWN - - return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value - - def _annotate_binary(self, expression: B) -> B: - self._annotate_args(expression) - - left, right = expression.left, expression.right - left_type, right_type = left.type.this, right.type.this # type: ignore - - if isinstance(expression, (exp.Connector, exp.Predicate)): - self._set_type(expression, exp.DataType.Type.BOOLEAN) - elif (left_type, right_type) in self.binary_coercions: - self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right)) - else: - self._set_type(expression, self._maybe_coerce(left_type, right_type)) - - return expression - - def _annotate_unary(self, expression: E) -> E: - self._annotate_args(expression) - - if isinstance(expression, exp.Not): - self._set_type(expression, exp.DataType.Type.BOOLEAN) - else: - self._set_type(expression, expression.this.type) - - return expression - - def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: - if expression.is_string: - self._set_type(expression, exp.DataType.Type.VARCHAR) - elif expression.is_int: - self._set_type(expression, exp.DataType.Type.INT) - else: - self._set_type(expression, exp.DataType.Type.DOUBLE) - - return expression - - def _annotate_with_type( - self, expression: E, target_type: exp.DataType | exp.DataType.Type - ) -> E: - self._set_type(expression, target_type) - return self._annotate_args(expression) - - @t.no_type_check - def _annotate_by_args( - self, - expression: E, - *args: str, - promote: bool = False, - array: bool = False, - ) -> E: - self._annotate_args(expression) - - expressions: t.List[exp.Expression] = [] - for arg in args: - arg_expr = expression.args.get(arg) - expressions.extend(expr for expr in ensure_list(arg_expr) if expr) - - last_datatype = None - for expr in expressions: - expr_type = expr.type - - # Stop at the first nested data type found - we don't want to _maybe_coerce nested types - if expr_type.args.get("nested"): - last_datatype = expr_type - break - - if not expr_type.is_type(exp.DataType.Type.UNKNOWN): - last_datatype = self._maybe_coerce(last_datatype or expr_type, expr_type) - - self._set_type(expression, last_datatype or exp.DataType.Type.UNKNOWN) - - if promote: - if expression.type.this in exp.DataType.INTEGER_TYPES: - self._set_type(expression, exp.DataType.Type.BIGINT) - elif expression.type.this in exp.DataType.FLOAT_TYPES: - self._set_type(expression, exp.DataType.Type.DOUBLE) - - if array: - self._set_type( - expression, - exp.DataType( - this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True - ), - ) - - return expression - - def _annotate_timeunit( - self, expression: exp.TimeUnit | exp.DateTrunc - ) -> exp.TimeUnit | exp.DateTrunc: - self._annotate_args(expression) - - if expression.this.type.this in exp.DataType.TEXT_TYPES: - datatype = _coerce_date_literal(expression.this, expression.unit) - elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES: - datatype = _coerce_date(expression.this, expression.unit) - else: - datatype = exp.DataType.Type.UNKNOWN - - 
self._set_type(expression, datatype) - return expression - - def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket: - self._annotate_args(expression) - - bracket_arg = expression.expressions[0] - this = expression.this - - if isinstance(bracket_arg, exp.Slice): - self._set_type(expression, this.type) - elif this.type.is_type(exp.DataType.Type.ARRAY): - self._set_type(expression, seq_get(this.type.expressions, 0)) - elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys: - index = this.keys.index(bracket_arg) - value = seq_get(this.values, index) - self._set_type(expression, value.type if value else None) - else: - self._set_type(expression, exp.DataType.Type.UNKNOWN) - - return expression - - def _annotate_div(self, expression: exp.Div) -> exp.Div: - self._annotate_args(expression) - - left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore - - if ( - expression.args.get("typed") - and left_type in exp.DataType.INTEGER_TYPES - and right_type in exp.DataType.INTEGER_TYPES - ): - self._set_type(expression, exp.DataType.Type.BIGINT) - else: - self._set_type(expression, self._maybe_coerce(left_type, right_type)) - if expression.type and expression.type.this not in exp.DataType.REAL_TYPES: - self._set_type( - expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE) - ) - - return expression - - def _annotate_dot(self, expression: exp.Dot) -> exp.Dot: - self._annotate_args(expression) - self._set_type(expression, None) - this_type = expression.this.type - - if this_type and this_type.is_type(exp.DataType.Type.STRUCT): - for e in this_type.expressions: - if e.name == expression.expression.name: - self._set_type(expression, e.kind) - break - - return expression - - def _annotate_explode(self, expression: exp.Explode) -> exp.Explode: - self._annotate_args(expression) - self._set_type(expression, seq_get(expression.this.type.expressions, 0)) - return expression - - def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest: - self._annotate_args(expression) - child = seq_get(expression.expressions, 0) - - if child and child.is_type(exp.DataType.Type.ARRAY): - expr_type = seq_get(child.type.expressions, 0) - else: - expr_type = None - - self._set_type(expression, expr_type) - return expression - - def _annotate_struct_value( - self, expression: exp.Expression - ) -> t.Optional[exp.DataType] | exp.ColumnDef: - alias = expression.args.get("alias") - if alias: - return exp.ColumnDef(this=alias.copy(), kind=expression.type) - - # Case: key = value or key := value - if expression.expression: - return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type) - - return expression.type - - def _annotate_struct(self, expression: exp.Struct) -> exp.Struct: - self._annotate_args(expression) - self._set_type( - expression, - exp.DataType( - this=exp.DataType.Type.STRUCT, - expressions=[self._annotate_struct_value(expr) for expr in expression.expressions], - nested=True, - ), - ) - return expression - - @t.overload - def _annotate_map(self, expression: exp.Map) -> exp.Map: ... - - @t.overload - def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ... 
- - def _annotate_map(self, expression): - self._annotate_args(expression) - - keys = expression.args.get("keys") - values = expression.args.get("values") - - map_type = exp.DataType(this=exp.DataType.Type.MAP) - if isinstance(keys, exp.Array) and isinstance(values, exp.Array): - key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN - value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN - - if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN: - map_type.set("expressions", [key_type, value_type]) - map_type.set("nested", True) - - self._set_type(expression, map_type) - return expression - - def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap: - self._annotate_args(expression) - - map_type = exp.DataType(this=exp.DataType.Type.MAP) - arg = expression.this - if arg.is_type(exp.DataType.Type.STRUCT): - for coldef in arg.type.expressions: - kind = coldef.kind - if kind != exp.DataType.Type.UNKNOWN: - map_type.set("expressions", [exp.DataType.build("varchar"), kind]) - map_type.set("nested", True) - break - - self._set_type(expression, map_type) - return expression - - def _annotate_extract(self, expression: exp.Extract) -> exp.Extract: - self._annotate_args(expression) - part = expression.name - if part == "TIME": - self._set_type(expression, exp.DataType.Type.TIME) - elif part == "DATE": - self._set_type(expression, exp.DataType.Type.DATE) - else: - self._set_type(expression, exp.DataType.Type.INT) - return expression diff --git a/altimate_packages/sqlglot/optimizer/canonicalize.py b/altimate_packages/sqlglot/optimizer/canonicalize.py deleted file mode 100644 index d654867d9..000000000 --- a/altimate_packages/sqlglot/optimizer/canonicalize.py +++ /dev/null @@ -1,222 +0,0 @@ -from __future__ import annotations - -import itertools -import typing as t - -from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, DialectType -from sqlglot.helper import is_date_unit, is_iso_date, is_iso_datetime -from sqlglot.optimizer.annotate_types import TypeAnnotator - - -def canonicalize(expression: exp.Expression, dialect: DialectType = None) -> exp.Expression: - """Converts a sql expression into a standard form. - - This method relies on annotate_types because many of the - conversions rely on type inference. - - Args: - expression: The expression to canonicalize. 
- """ - - dialect = Dialect.get_or_raise(dialect) - - def _canonicalize(expression: exp.Expression) -> exp.Expression: - expression = add_text_to_concat(expression) - expression = replace_date_funcs(expression, dialect=dialect) - expression = coerce_type(expression, dialect.PROMOTE_TO_INFERRED_DATETIME_TYPE) - expression = remove_redundant_casts(expression) - expression = ensure_bools(expression, _replace_int_predicate) - expression = remove_ascending_order(expression) - return expression - - return exp.replace_tree(expression, _canonicalize) - - -def add_text_to_concat(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES: - node = exp.Concat(expressions=[node.left, node.right]) - return node - - -def replace_date_funcs(node: exp.Expression, dialect: DialectType) -> exp.Expression: - if ( - isinstance(node, (exp.Date, exp.TsOrDsToDate)) - and not node.expressions - and not node.args.get("zone") - and node.this.is_string - and is_iso_date(node.this.name) - ): - return exp.cast(node.this, to=exp.DataType.Type.DATE) - if isinstance(node, exp.Timestamp) and not node.args.get("zone"): - if not node.type: - from sqlglot.optimizer.annotate_types import annotate_types - - node = annotate_types(node, dialect=dialect) - return exp.cast(node.this, to=node.type or exp.DataType.Type.TIMESTAMP) - - return node - - -COERCIBLE_DATE_OPS = ( - exp.Add, - exp.Sub, - exp.EQ, - exp.NEQ, - exp.GT, - exp.GTE, - exp.LT, - exp.LTE, - exp.NullSafeEQ, - exp.NullSafeNEQ, -) - - -def coerce_type(node: exp.Expression, promote_to_inferred_datetime_type: bool) -> exp.Expression: - if isinstance(node, COERCIBLE_DATE_OPS): - _coerce_date(node.left, node.right, promote_to_inferred_datetime_type) - elif isinstance(node, exp.Between): - _coerce_date(node.this, node.args["low"], promote_to_inferred_datetime_type) - elif isinstance(node, exp.Extract) and not node.expression.type.is_type( - *exp.DataType.TEMPORAL_TYPES - ): - _replace_cast(node.expression, exp.DataType.Type.DATETIME) - elif isinstance(node, (exp.DateAdd, exp.DateSub, exp.DateTrunc)): - _coerce_timeunit_arg(node.this, node.unit) - elif isinstance(node, exp.DateDiff): - _coerce_datediff_args(node) - - return node - - -def remove_redundant_casts(expression: exp.Expression) -> exp.Expression: - if ( - isinstance(expression, exp.Cast) - and expression.this.type - and expression.to == expression.this.type - ): - return expression.this - - if ( - isinstance(expression, (exp.Date, exp.TsOrDsToDate)) - and expression.this.type - and expression.this.type.this == exp.DataType.Type.DATE - and not expression.this.type.expressions - ): - return expression.this - - return expression - - -def ensure_bools( - expression: exp.Expression, replace_func: t.Callable[[exp.Expression], None] -) -> exp.Expression: - if isinstance(expression, exp.Connector): - replace_func(expression.left) - replace_func(expression.right) - elif isinstance(expression, exp.Not): - replace_func(expression.this) - # We can't replace num in CASE x WHEN num ..., because it's not the full predicate - elif isinstance(expression, exp.If) and not ( - isinstance(expression.parent, exp.Case) and expression.parent.this - ): - replace_func(expression.this) - elif isinstance(expression, (exp.Where, exp.Having)): - replace_func(expression.this) - - return expression - - -def remove_ascending_order(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Ordered) and expression.args.get("desc") is False: - # Convert 
ORDER BY a ASC to ORDER BY a - expression.set("desc", None) - - return expression - - -def _coerce_date( - a: exp.Expression, - b: exp.Expression, - promote_to_inferred_datetime_type: bool, -) -> None: - for a, b in itertools.permutations([a, b]): - if isinstance(b, exp.Interval): - a = _coerce_timeunit_arg(a, b.unit) - - a_type = a.type - if ( - not a_type - or a_type.this not in exp.DataType.TEMPORAL_TYPES - or not b.type - or b.type.this not in exp.DataType.TEXT_TYPES - ): - continue - - if promote_to_inferred_datetime_type: - if b.is_string: - date_text = b.name - if is_iso_date(date_text): - b_type = exp.DataType.Type.DATE - elif is_iso_datetime(date_text): - b_type = exp.DataType.Type.DATETIME - else: - b_type = a_type.this - else: - # If b is not a datetime string, we conservatively promote it to a DATETIME, - # in order to ensure there are no surprising truncations due to downcasting - b_type = exp.DataType.Type.DATETIME - - target_type = ( - b_type if b_type in TypeAnnotator.COERCES_TO.get(a_type.this, {}) else a_type - ) - else: - target_type = a_type - - if target_type != a_type: - _replace_cast(a, target_type) - - _replace_cast(b, target_type) - - -def _coerce_timeunit_arg(arg: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.Expression: - if not arg.type: - return arg - - if arg.type.this in exp.DataType.TEXT_TYPES: - date_text = arg.name - is_iso_date_ = is_iso_date(date_text) - - if is_iso_date_ and is_date_unit(unit): - return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATE)) - - # An ISO date is also an ISO datetime, but not vice versa - if is_iso_date_ or is_iso_datetime(date_text): - return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) - - elif arg.type.this == exp.DataType.Type.DATE and not is_date_unit(unit): - return arg.replace(exp.cast(arg.copy(), to=exp.DataType.Type.DATETIME)) - - return arg - - -def _coerce_datediff_args(node: exp.DateDiff) -> None: - for e in (node.this, node.expression): - if e.type.this not in exp.DataType.TEMPORAL_TYPES: - e.replace(exp.cast(e.copy(), to=exp.DataType.Type.DATETIME)) - - -def _replace_cast(node: exp.Expression, to: exp.DATA_TYPE) -> None: - node.replace(exp.cast(node.copy(), to=to)) - - -# this was originally designed for presto, there is a similar transform for tsql -# this is different in that it only operates on int types, this is because -# presto has a boolean type whereas tsql doesn't (people use bits) -# with y as (select true as x) select x = 0 FROM y -- illegal presto query -def _replace_int_predicate(expression: exp.Expression) -> None: - if isinstance(expression, exp.Coalesce): - for child in expression.iter_expressions(): - _replace_int_predicate(child) - elif expression.type and expression.type.this in exp.DataType.INTEGER_TYPES: - expression.replace(expression.neq(0)) diff --git a/altimate_packages/sqlglot/optimizer/eliminate_ctes.py b/altimate_packages/sqlglot/optimizer/eliminate_ctes.py deleted file mode 100644 index d2e876cd0..000000000 --- a/altimate_packages/sqlglot/optimizer/eliminate_ctes.py +++ /dev/null @@ -1,43 +0,0 @@ -from sqlglot.optimizer.scope import Scope, build_scope - - -def eliminate_ctes(expression): - """ - Remove unused CTEs from an expression. 
- - Example: - >>> import sqlglot - >>> sql = "WITH y AS (SELECT a FROM x) SELECT a FROM z" - >>> expression = sqlglot.parse_one(sql) - >>> eliminate_ctes(expression).sql() - 'SELECT a FROM z' - - Args: - expression (sqlglot.Expression): expression to optimize - Returns: - sqlglot.Expression: optimized expression - """ - root = build_scope(expression) - - if root: - ref_count = root.ref_count() - - # Traverse the scope tree in reverse so we can remove chains of unused CTEs - for scope in reversed(list(root.traverse())): - if scope.is_cte: - count = ref_count[id(scope)] - if count <= 0: - cte_node = scope.expression.parent - with_node = cte_node.parent - cte_node.pop() - - # Pop the entire WITH clause if this is the last CTE - if with_node and len(with_node.expressions) <= 0: - with_node.pop() - - # Decrement the ref count for all sources this CTE selects from - for _, source in scope.selected_sources.values(): - if isinstance(source, Scope): - ref_count[id(source)] -= 1 - - return expression diff --git a/altimate_packages/sqlglot/optimizer/eliminate_joins.py b/altimate_packages/sqlglot/optimizer/eliminate_joins.py deleted file mode 100644 index 3134e6598..000000000 --- a/altimate_packages/sqlglot/optimizer/eliminate_joins.py +++ /dev/null @@ -1,181 +0,0 @@ -from sqlglot import expressions as exp -from sqlglot.optimizer.normalize import normalized -from sqlglot.optimizer.scope import Scope, traverse_scope - - -def eliminate_joins(expression): - """ - Remove unused joins from an expression. - - This only removes joins when we know that the join condition doesn't produce duplicate rows. - - Example: - >>> import sqlglot - >>> sql = "SELECT x.a FROM x LEFT JOIN (SELECT DISTINCT y.b FROM y) AS y ON x.b = y.b" - >>> expression = sqlglot.parse_one(sql) - >>> eliminate_joins(expression).sql() - 'SELECT x.a FROM x' - - Args: - expression (sqlglot.Expression): expression to optimize - Returns: - sqlglot.Expression: optimized expression - """ - for scope in traverse_scope(expression): - # If any columns in this scope aren't qualified, it's hard to determine if a join isn't used. - # It's probably possible to infer this from the outputs of derived tables. - # But for now, let's just skip this rule. - if scope.unqualified_columns: - continue - - joins = scope.expression.args.get("joins", []) - - # Reverse the joins so we can remove chains of unused joins - for join in reversed(joins): - alias = join.alias_or_name - if _should_eliminate_join(scope, join, alias): - join.pop() - scope.remove_source(alias) - return expression - - -def _should_eliminate_join(scope, join, alias): - inner_source = scope.sources.get(alias) - return ( - isinstance(inner_source, Scope) - and not _join_is_used(scope, join, alias) - and ( - (join.side == "LEFT" and _is_joined_on_all_unique_outputs(inner_source, join)) - or (not join.args.get("on") and _has_single_output_row(inner_source)) - ) - ) - - -def _join_is_used(scope, join, alias): - # We need to find all columns that reference this join. - # But columns in the ON clause shouldn't count. 
- on = join.args.get("on") - if on: - on_clause_columns = {id(column) for column in on.find_all(exp.Column)} - else: - on_clause_columns = set() - return any( - column for column in scope.source_columns(alias) if id(column) not in on_clause_columns - ) - - -def _is_joined_on_all_unique_outputs(scope, join): - unique_outputs = _unique_outputs(scope) - if not unique_outputs: - return False - - _, join_keys, _ = join_condition(join) - remaining_unique_outputs = unique_outputs - {c.name for c in join_keys} - return not remaining_unique_outputs - - -def _unique_outputs(scope): - """Determine output columns of `scope` that must have a unique combination per row""" - if scope.expression.args.get("distinct"): - return set(scope.expression.named_selects) - - group = scope.expression.args.get("group") - if group: - grouped_expressions = set(group.expressions) - grouped_outputs = set() - - unique_outputs = set() - for select in scope.expression.selects: - output = select.unalias() - if output in grouped_expressions: - grouped_outputs.add(output) - unique_outputs.add(select.alias_or_name) - - # All the grouped expressions must be in the output - if not grouped_expressions.difference(grouped_outputs): - return unique_outputs - else: - return set() - - if _has_single_output_row(scope): - return set(scope.expression.named_selects) - - return set() - - -def _has_single_output_row(scope): - return isinstance(scope.expression, exp.Select) and ( - all(isinstance(e.unalias(), exp.AggFunc) for e in scope.expression.selects) - or _is_limit_1(scope) - or not scope.expression.args.get("from") - ) - - -def _is_limit_1(scope): - limit = scope.expression.args.get("limit") - return limit and limit.expression.this == "1" - - -def join_condition(join): - """ - Extract the join condition from a join expression. 
- - Args: - join (exp.Join) - Returns: - tuple[list[str], list[str], exp.Expression]: - Tuple of (source key, join key, remaining predicate) - """ - name = join.alias_or_name - on = (join.args.get("on") or exp.true()).copy() - source_key = [] - join_key = [] - - def extract_condition(condition): - left, right = condition.unnest_operands() - left_tables = exp.column_table_names(left) - right_tables = exp.column_table_names(right) - - if name in left_tables and name not in right_tables: - join_key.append(left) - source_key.append(right) - condition.replace(exp.true()) - elif name in right_tables and name not in left_tables: - join_key.append(right) - source_key.append(left) - condition.replace(exp.true()) - - # find the join keys - # SELECT - # FROM x - # JOIN y - # ON x.a = y.b AND y.b > 1 - # - # should pull y.b as the join key and x.a as the source key - if normalized(on): - on = on if isinstance(on, exp.And) else exp.and_(on, exp.true(), copy=False) - - for condition in on.flatten(): - if isinstance(condition, exp.EQ): - extract_condition(condition) - elif normalized(on, dnf=True): - conditions = None - - for condition in on.flatten(): - parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)] - if conditions is None: - conditions = parts - else: - temp = [] - for p in parts: - cs = [c for c in conditions if p == c] - - if cs: - temp.append(p) - temp.extend(cs) - conditions = temp - - for condition in conditions: - extract_condition(condition) - - return source_key, join_key, on diff --git a/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py b/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py deleted file mode 100644 index b66100369..000000000 --- a/altimate_packages/sqlglot/optimizer/eliminate_subqueries.py +++ /dev/null @@ -1,189 +0,0 @@ -from __future__ import annotations - -import itertools -import typing as t - -from sqlglot import expressions as exp -from sqlglot.helper import find_new_name -from sqlglot.optimizer.scope import Scope, build_scope - -if t.TYPE_CHECKING: - ExistingCTEsMapping = t.Dict[exp.Expression, str] - TakenNameMapping = t.Dict[str, t.Union[Scope, exp.Expression]] - - -def eliminate_subqueries(expression: exp.Expression) -> exp.Expression: - """ - Rewrite derived tables as CTES, deduplicating if possible. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y") - >>> eliminate_subqueries(expression).sql() - 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y' - - This also deduplicates common subqueries: - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT * FROM x) AS y CROSS JOIN (SELECT * FROM x) AS z") - >>> eliminate_subqueries(expression).sql() - 'WITH y AS (SELECT * FROM x) SELECT a FROM y AS y CROSS JOIN y AS z' - - Args: - expression (sqlglot.Expression): expression - Returns: - sqlglot.Expression: expression - """ - if isinstance(expression, exp.Subquery): - # It's possible to have subqueries at the root, e.g. (SELECT * FROM x) LIMIT 1 - eliminate_subqueries(expression.this) - return expression - - root = build_scope(expression) - - if not root: - return expression - - # Map of alias->Scope|Table - # These are all aliases that are already used in the expression. - # We don't want to create new CTEs that conflict with these names. 
- taken: TakenNameMapping = {} - - # All CTE aliases in the root scope are taken - for scope in root.cte_scopes: - taken[scope.expression.parent.alias] = scope - - # All table names are taken - for scope in root.traverse(): - taken.update( - { - source.name: source - for _, source in scope.sources.items() - if isinstance(source, exp.Table) - } - ) - - # Map of Expression->alias - # Existing CTES in the root expression. We'll use this for deduplication. - existing_ctes: ExistingCTEsMapping = {} - - with_ = root.expression.args.get("with") - recursive = False - if with_: - recursive = with_.args.get("recursive") - for cte in with_.expressions: - existing_ctes[cte.this] = cte.alias - new_ctes = [] - - # We're adding more CTEs, but we want to maintain the DAG order. - # Derived tables within an existing CTE need to come before the existing CTE. - for cte_scope in root.cte_scopes: - # Append all the new CTEs from this existing CTE - for scope in cte_scope.traverse(): - if scope is cte_scope: - # Don't try to eliminate this CTE itself - continue - new_cte = _eliminate(scope, existing_ctes, taken) - if new_cte: - new_ctes.append(new_cte) - - # Append the existing CTE itself - new_ctes.append(cte_scope.expression.parent) - - # Now append the rest - for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.table_scopes): - for child_scope in scope.traverse(): - new_cte = _eliminate(child_scope, existing_ctes, taken) - if new_cte: - new_ctes.append(new_cte) - - if new_ctes: - query = expression.expression if isinstance(expression, exp.DDL) else expression - query.set("with", exp.With(expressions=new_ctes, recursive=recursive)) - - return expression - - -def _eliminate( - scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping -) -> t.Optional[exp.Expression]: - if scope.is_derived_table: - return _eliminate_derived_table(scope, existing_ctes, taken) - - if scope.is_cte: - return _eliminate_cte(scope, existing_ctes, taken) - - return None - - -def _eliminate_derived_table( - scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping -) -> t.Optional[exp.Expression]: - # This makes sure that we don't: - # - drop the "pivot" arg from a pivoted subquery - # - eliminate a lateral correlated subquery - if scope.parent.pivots or isinstance(scope.parent.expression, exp.Lateral): - return None - - # Get rid of redundant exp.Subquery expressions, i.e. 
those that are just used as wrappers - to_replace = scope.expression.parent.unwrap() - name, cte = _new_cte(scope, existing_ctes, taken) - table = exp.alias_(exp.table_(name), alias=to_replace.alias or name) - table.set("joins", to_replace.args.get("joins")) - - to_replace.replace(table) - - return cte - - -def _eliminate_cte( - scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping -) -> t.Optional[exp.Expression]: - parent = scope.expression.parent - name, cte = _new_cte(scope, existing_ctes, taken) - - with_ = parent.parent - parent.pop() - if not with_.expressions: - with_.pop() - - # Rename references to this CTE - for child_scope in scope.parent.traverse(): - for table, source in child_scope.selected_sources.values(): - if source is scope: - new_table = exp.alias_(exp.table_(name), alias=table.alias_or_name, copy=False) - table.replace(new_table) - - return cte - - -def _new_cte( - scope: Scope, existing_ctes: ExistingCTEsMapping, taken: TakenNameMapping -) -> t.Tuple[str, t.Optional[exp.Expression]]: - """ - Returns: - tuple of (name, cte) - where `name` is a new name for this CTE in the root scope and `cte` is a new CTE instance. - If this CTE duplicates an existing CTE, `cte` will be None. - """ - duplicate_cte_alias = existing_ctes.get(scope.expression) - parent = scope.expression.parent - name = parent.alias - - if not name: - name = find_new_name(taken=taken, base="cte") - - if duplicate_cte_alias: - name = duplicate_cte_alias - elif taken.get(name): - name = find_new_name(taken=taken, base=name) - - taken[name] = scope - - if not duplicate_cte_alias: - existing_ctes[scope.expression] = name - cte = exp.CTE( - this=scope.expression, - alias=exp.TableAlias(this=exp.to_identifier(name)), - ) - else: - cte = None - return name, cte diff --git a/altimate_packages/sqlglot/optimizer/isolate_table_selects.py b/altimate_packages/sqlglot/optimizer/isolate_table_selects.py deleted file mode 100644 index e0c1d9c1d..000000000 --- a/altimate_packages/sqlglot/optimizer/isolate_table_selects.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import alias, exp -from sqlglot.errors import OptimizeError -from sqlglot.optimizer.scope import traverse_scope -from sqlglot.schema import ensure_schema - -if t.TYPE_CHECKING: - from sqlglot._typing import E - from sqlglot.schema import Schema - from sqlglot.dialects.dialect import DialectType - - -def isolate_table_selects( - expression: E, - schema: t.Optional[t.Dict | Schema] = None, - dialect: DialectType = None, -) -> E: - schema = ensure_schema(schema, dialect=dialect) - - for scope in traverse_scope(expression): - if len(scope.selected_sources) == 1: - continue - - for _, source in scope.selected_sources.values(): - assert source.parent - - if ( - not isinstance(source, exp.Table) - or not schema.column_names(source) - or isinstance(source.parent, exp.Subquery) - or isinstance(source.parent.parent, exp.Table) - ): - continue - - if not source.alias: - raise OptimizeError("Tables require an alias. 
Run qualify_tables optimization.") - - source.replace( - exp.select("*") - .from_( - alias(source, source.alias_or_name, table=True), - copy=False, - ) - .subquery(source.alias, copy=False) - ) - - return expression diff --git a/altimate_packages/sqlglot/optimizer/merge_subqueries.py b/altimate_packages/sqlglot/optimizer/merge_subqueries.py deleted file mode 100644 index 358a21857..000000000 --- a/altimate_packages/sqlglot/optimizer/merge_subqueries.py +++ /dev/null @@ -1,415 +0,0 @@ -from __future__ import annotations - -import typing as t - -from collections import defaultdict - -from sqlglot import expressions as exp -from sqlglot.helper import find_new_name, seq_get -from sqlglot.optimizer.scope import Scope, traverse_scope - -if t.TYPE_CHECKING: - from sqlglot._typing import E - - FromOrJoin = t.Union[exp.From, exp.Join] - - -def merge_subqueries(expression: E, leave_tables_isolated: bool = False) -> E: - """ - Rewrite sqlglot AST to merge derived tables into the outer query. - - This also merges CTEs if they are selected from only once. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") - >>> merge_subqueries(expression).sql() - 'SELECT x.a FROM x CROSS JOIN y' - - If `leave_tables_isolated` is True, this will not merge inner queries into outer - queries if it would result in multiple table selects in a single query: - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y") - >>> merge_subqueries(expression, leave_tables_isolated=True).sql() - 'SELECT a FROM (SELECT x.a FROM x) CROSS JOIN y' - - Inspired by https://dev.mysql.com/doc/refman/8.0/en/derived-table-optimization.html - - Args: - expression (sqlglot.Expression): expression to optimize - leave_tables_isolated (bool): - Returns: - sqlglot.Expression: optimized expression - """ - expression = merge_ctes(expression, leave_tables_isolated) - expression = merge_derived_tables(expression, leave_tables_isolated) - return expression - - -# If a derived table has these Select args, it can't be merged -UNMERGABLE_ARGS = set(exp.Select.arg_types) - { - "expressions", - "from", - "joins", - "where", - "order", - "hint", -} - - -# Projections in the outer query that are instances of these types can be replaced -# without getting wrapped in parentheses, because the precedence won't be altered. -SAFE_TO_REPLACE_UNWRAPPED = ( - exp.Column, - exp.EQ, - exp.Func, - exp.NEQ, - exp.Paren, -) - - -def merge_ctes(expression: E, leave_tables_isolated: bool = False) -> E: - scopes = traverse_scope(expression) - - # All places where we select from CTEs. - # We key on the CTE scope so we can detect CTES that are selected from multiple times. 
- cte_selections = defaultdict(list) - for outer_scope in scopes: - for table, inner_scope in outer_scope.selected_sources.values(): - if isinstance(inner_scope, Scope) and inner_scope.is_cte: - cte_selections[id(inner_scope)].append( - ( - outer_scope, - inner_scope, - table, - ) - ) - - singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] - for outer_scope, inner_scope, table in singular_cte_selections: - from_or_join = table.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): - alias = table.alias_or_name - _rename_inner_sources(outer_scope, inner_scope, alias) - _merge_from(outer_scope, inner_scope, table, alias) - _merge_expressions(outer_scope, inner_scope, alias) - _merge_order(outer_scope, inner_scope) - _merge_joins(outer_scope, inner_scope, from_or_join) - _merge_where(outer_scope, inner_scope, from_or_join) - _merge_hints(outer_scope, inner_scope) - _pop_cte(inner_scope) - outer_scope.clear_cache() - return expression - - -def merge_derived_tables(expression: E, leave_tables_isolated: bool = False) -> E: - for outer_scope in traverse_scope(expression): - for subquery in outer_scope.derived_tables: - from_or_join = subquery.find_ancestor(exp.From, exp.Join) - alias = subquery.alias_or_name - inner_scope = outer_scope.sources[alias] - if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): - _rename_inner_sources(outer_scope, inner_scope, alias) - _merge_from(outer_scope, inner_scope, subquery, alias) - _merge_expressions(outer_scope, inner_scope, alias) - _merge_order(outer_scope, inner_scope) - _merge_joins(outer_scope, inner_scope, from_or_join) - _merge_where(outer_scope, inner_scope, from_or_join) - _merge_hints(outer_scope, inner_scope) - outer_scope.clear_cache() - - return expression - - -def _mergeable( - outer_scope: Scope, inner_scope: Scope, leave_tables_isolated: bool, from_or_join: FromOrJoin -) -> bool: - """ - Return True if `inner_select` can be merged into outer query. - """ - inner_select = inner_scope.expression.unnest() - - def _is_a_window_expression_in_unmergable_operation(): - window_aliases = {s.alias_or_name for s in inner_select.selects if s.find(exp.Window)} - inner_select_name = from_or_join.alias_or_name - unmergable_window_columns = [ - column - for column in outer_scope.columns - if column.find_ancestor( - exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc - ) - ] - window_expressions_in_unmergable = [ - column - for column in unmergable_window_columns - if column.table == inner_select_name and column.name in window_aliases - ] - return any(window_expressions_in_unmergable) - - def _outer_select_joins_on_inner_select_join(): - """ - All columns from the inner select in the ON clause must be from the first FROM table. 
- - That is, this can be merged: - SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a - ^^^ ^ - But this can't: - SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a - ^^^ ^ - """ - if not isinstance(from_or_join, exp.Join): - return False - - alias = from_or_join.alias_or_name - - on = from_or_join.args.get("on") - if not on: - return False - selections = [c.name for c in on.find_all(exp.Column) if c.table == alias] - inner_from = inner_scope.expression.args.get("from") - if not inner_from: - return False - inner_from_table = inner_from.alias_or_name - inner_projections = {s.alias_or_name: s for s in inner_scope.expression.selects} - return any( - col.table != inner_from_table - for selection in selections - for col in inner_projections[selection].find_all(exp.Column) - ) - - def _is_recursive(): - # Recursive CTEs look like this: - # WITH RECURSIVE cte AS ( - # SELECT * FROM x <-- inner scope - # UNION ALL - # SELECT * FROM cte <-- outer scope - # ) - cte = inner_scope.expression.parent - node = outer_scope.expression.parent - - while node: - if node is cte: - return True - node = node.parent - return False - - return ( - isinstance(outer_scope.expression, exp.Select) - and not outer_scope.expression.is_star - and isinstance(inner_select, exp.Select) - and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) - and inner_select.args.get("from") is not None - and not outer_scope.pivots - and not any(e.find(exp.AggFunc, exp.Select, exp.Explode) for e in inner_select.expressions) - and not (leave_tables_isolated and len(outer_scope.selected_sources) > 1) - and not ( - isinstance(from_or_join, exp.Join) - and inner_select.args.get("where") - and from_or_join.side in ("FULL", "LEFT", "RIGHT") - ) - and not ( - isinstance(from_or_join, exp.From) - and inner_select.args.get("where") - and any( - j.side in ("FULL", "RIGHT") for j in outer_scope.expression.args.get("joins", []) - ) - ) - and not _outer_select_joins_on_inner_select_join() - and not _is_a_window_expression_in_unmergable_operation() - and not _is_recursive() - and not (inner_select.args.get("order") and outer_scope.is_union) - and not isinstance(seq_get(inner_select.expressions, 0), exp.QueryTransform) - ) - - -def _rename_inner_sources(outer_scope: Scope, inner_scope: Scope, alias: str) -> None: - """ - Renames any sources in the inner query that conflict with names in the outer query. - """ - inner_taken = set(inner_scope.selected_sources) - outer_taken = set(outer_scope.selected_sources) - conflicts = outer_taken.intersection(inner_taken) - conflicts -= {alias} - - taken = outer_taken.union(inner_taken) - - for conflict in conflicts: - new_name = find_new_name(taken, conflict) - - source, _ = inner_scope.selected_sources[conflict] - new_alias = exp.to_identifier(new_name) - - if isinstance(source, exp.Table) and source.alias: - source.set("alias", new_alias) - elif isinstance(source, exp.Table): - source.replace(exp.alias_(source, new_alias)) - elif isinstance(source.parent, exp.Subquery): - source.parent.set("alias", exp.TableAlias(this=new_alias)) - - for column in inner_scope.source_columns(conflict): - column.set("table", exp.to_identifier(new_name)) - - inner_scope.rename_source(conflict, new_name) - - -def _merge_from( - outer_scope: Scope, - inner_scope: Scope, - node_to_replace: t.Union[exp.Subquery, exp.Table], - alias: str, -) -> None: - """ - Merge FROM clause of inner query into outer query. 
- """ - new_subquery = inner_scope.expression.args["from"].this - new_subquery.set("joins", node_to_replace.args.get("joins")) - node_to_replace.replace(new_subquery) - for join_hint in outer_scope.join_hints: - tables = join_hint.find_all(exp.Table) - for table in tables: - if table.alias_or_name == node_to_replace.alias_or_name: - table.set("this", exp.to_identifier(new_subquery.alias_or_name)) - outer_scope.remove_source(alias) - outer_scope.add_source( - new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name] - ) - - -def _merge_joins(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None: - """ - Merge JOIN clauses of inner query into outer query. - """ - - new_joins = [] - - joins = inner_scope.expression.args.get("joins") or [] - for join in joins: - new_joins.append(join) - outer_scope.add_source(join.alias_or_name, inner_scope.sources[join.alias_or_name]) - - if new_joins: - outer_joins = outer_scope.expression.args.get("joins", []) - - # Maintain the join order - if isinstance(from_or_join, exp.From): - position = 0 - else: - position = outer_joins.index(from_or_join) + 1 - outer_joins[position:position] = new_joins - - outer_scope.expression.set("joins", outer_joins) - - -def _merge_expressions(outer_scope: Scope, inner_scope: Scope, alias: str) -> None: - """ - Merge projections of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - alias (str) - """ - # Collect all columns that reference the alias of the inner query - outer_columns = defaultdict(list) - for column in outer_scope.columns: - if column.table == alias: - outer_columns[column.name].append(column) - - # Replace columns with the projection expression in the inner query - for expression in inner_scope.expression.expressions: - projection_name = expression.alias_or_name - if not projection_name: - continue - columns_to_replace = outer_columns.get(projection_name, []) - - expression = expression.unalias() - must_wrap_expression = not isinstance(expression, SAFE_TO_REPLACE_UNWRAPPED) - - for column in columns_to_replace: - # Ensures we don't alter the intended operator precedence if there's additional - # context surrounding the outer expression (i.e. it's not a simple projection). - if isinstance(column.parent, (exp.Unary, exp.Binary)) and must_wrap_expression: - expression = exp.paren(expression, copy=False) - - column.replace(expression.copy()) - - -def _merge_where(outer_scope: Scope, inner_scope: Scope, from_or_join: FromOrJoin) -> None: - """ - Merge WHERE clause of inner query into outer query. 
- - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - from_or_join (exp.From|exp.Join) - """ - where = inner_scope.expression.args.get("where") - if not where or not where.this: - return - - expression = outer_scope.expression - - if isinstance(from_or_join, exp.Join): - # Merge predicates from an outer join to the ON clause - # if it only has columns that are already joined - from_ = expression.args.get("from") - sources = {from_.alias_or_name} if from_ else set() - - for join in expression.args["joins"]: - source = join.alias_or_name - sources.add(source) - if source == from_or_join.alias_or_name: - break - - if exp.column_table_names(where.this) <= sources: - from_or_join.on(where.this, copy=False) - from_or_join.set("on", from_or_join.args.get("on")) - return - - expression.where(where.this, copy=False) - - -def _merge_order(outer_scope: Scope, inner_scope: Scope) -> None: - """ - Merge ORDER clause of inner query into outer query. - - Args: - outer_scope (sqlglot.optimizer.scope.Scope) - inner_scope (sqlglot.optimizer.scope.Scope) - """ - if ( - any( - outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"] - ) - or len(outer_scope.selected_sources) != 1 - or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions) - ): - return - - outer_scope.expression.set("order", inner_scope.expression.args.get("order")) - - -def _merge_hints(outer_scope: Scope, inner_scope: Scope) -> None: - inner_scope_hint = inner_scope.expression.args.get("hint") - if not inner_scope_hint: - return - outer_scope_hint = outer_scope.expression.args.get("hint") - if outer_scope_hint: - for hint_expression in inner_scope_hint.expressions: - outer_scope_hint.append("expressions", hint_expression) - else: - outer_scope.expression.set("hint", inner_scope_hint) - - -def _pop_cte(inner_scope: Scope) -> None: - """ - Remove CTE from the AST. - - Args: - inner_scope (sqlglot.optimizer.scope.Scope) - """ - cte = inner_scope.expression.parent - with_ = cte.parent - if len(with_.expressions) == 1: - with_.pop() - else: - cte.pop() diff --git a/altimate_packages/sqlglot/optimizer/normalize.py b/altimate_packages/sqlglot/optimizer/normalize.py deleted file mode 100644 index 610833d4a..000000000 --- a/altimate_packages/sqlglot/optimizer/normalize.py +++ /dev/null @@ -1,200 +0,0 @@ -from __future__ import annotations - -import logging - -from sqlglot import exp -from sqlglot.errors import OptimizeError -from sqlglot.helper import while_changing -from sqlglot.optimizer.scope import find_all_in_scope -from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort - -logger = logging.getLogger("sqlglot") - - -def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): - """ - Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("(x AND y) OR z") - >>> normalize(expression, dnf=False).sql() - '(x OR z) AND (y OR z)' - - Args: - expression: expression to normalize - dnf: rewrite in disjunctive normal form instead. 
- max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion - Returns: - sqlglot.Expression: normalized expression - """ - for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))): - if isinstance(node, exp.Connector): - if normalized(node, dnf=dnf): - continue - root = node is expression - original = node.copy() - - node.transform(rewrite_between, copy=False) - distance = normalization_distance(node, dnf=dnf, max_=max_distance) - - if distance > max_distance: - logger.info( - f"Skipping normalization because distance {distance} exceeds max {max_distance}" - ) - return expression - - try: - node = node.replace( - while_changing(node, lambda e: distributive_law(e, dnf, max_distance)) - ) - except OptimizeError as e: - logger.info(e) - node.replace(original) - if root: - return original - return expression - - if root: - expression = node - - return expression - - -def normalized(expression: exp.Expression, dnf: bool = False) -> bool: - """ - Checks whether a given expression is in a normal form of interest. - - Example: - >>> from sqlglot import parse_one - >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) - True - >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default - True - >>> normalized(parse_one("a AND (b OR c)"), dnf=True) - False - - Args: - expression: The expression to check if it's normalized. - dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). - Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). - """ - ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) - return not any( - connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root) - ) - - -def normalization_distance( - expression: exp.Expression, dnf: bool = False, max_: float = float("inf") -) -> int: - """ - The difference in the number of predicates between a given expression and its normalized form. - - This is used as an estimate of the cost of the conversion which is exponential in complexity. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") - >>> normalization_distance(expression) - 4 - - Args: - expression: The expression to compute the normalization distance for. - dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). - Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). - max_: stop early if count exceeds this. - - Returns: - The normalization distance. - """ - total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1) - - for length in _predicate_lengths(expression, dnf, max_): - total += length - if total > max_: - return total - - return total - - -def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0): - """ - Returns a list of predicate lengths when expanded to normalized form. - - (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 
- """ - if depth > max_: - yield depth - return - - expression = expression.unnest() - - if not isinstance(expression, exp.Connector): - yield 1 - return - - depth += 1 - left, right = expression.args.values() - - if isinstance(expression, exp.And if dnf else exp.Or): - for a in _predicate_lengths(left, dnf, max_, depth): - for b in _predicate_lengths(right, dnf, max_, depth): - yield a + b - else: - yield from _predicate_lengths(left, dnf, max_, depth) - yield from _predicate_lengths(right, dnf, max_, depth) - - -def distributive_law(expression, dnf, max_distance): - """ - x OR (y AND z) -> (x OR y) AND (x OR z) - (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) - """ - if normalized(expression, dnf=dnf): - return expression - - distance = normalization_distance(expression, dnf=dnf, max_=max_distance) - - if distance > max_distance: - raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") - - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) - to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) - - if isinstance(expression, from_exp): - a, b = expression.unnest_operands() - - from_func = exp.and_ if from_exp == exp.And else exp.or_ - to_func = exp.and_ if to_exp == exp.And else exp.or_ - - if isinstance(a, to_exp) and isinstance(b, to_exp): - if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func) - return _distribute(b, a, from_func, to_func) - if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func) - if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func) - - return expression - - -def _distribute(a, b, from_func, to_func): - if isinstance(a, exp.Connector): - exp.replace_children( - a, - lambda c: to_func( - uniq_sort(flatten(from_func(c, b.left))), - uniq_sort(flatten(from_func(c, b.right))), - copy=False, - ), - ) - else: - a = to_func( - uniq_sort(flatten(from_func(a, b.left))), - uniq_sort(flatten(from_func(a, b.right))), - copy=False, - ) - - return a diff --git a/altimate_packages/sqlglot/optimizer/normalize_identifiers.py b/altimate_packages/sqlglot/optimizer/normalize_identifiers.py deleted file mode 100644 index dd421d9bb..000000000 --- a/altimate_packages/sqlglot/optimizer/normalize_identifiers.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, DialectType - -if t.TYPE_CHECKING: - from sqlglot._typing import E - - -@t.overload -def normalize_identifiers(expression: E, dialect: DialectType = None) -> E: ... - - -@t.overload -def normalize_identifiers(expression: str, dialect: DialectType = None) -> exp.Identifier: ... - - -def normalize_identifiers(expression, dialect=None): - """ - Normalize identifiers by converting them to either lower or upper case, - ensuring the semantics are preserved in each case (e.g. by respecting - case-sensitivity). - - This transformation reflects how identifiers would be resolved by the engine corresponding - to each SQL dialect, and plays a very important role in the standardization of the AST. - - It's possible to make this a no-op by adding a special comment next to the - identifier of interest: - - SELECT a /* sqlglot.meta case_sensitive */ FROM table - - In this example, the identifier `a` will not be normalized. - - Note: - Some dialects (e.g. 
DuckDB) treat all identifiers as case-insensitive even - when they're quoted, so in these cases all identifiers are normalized. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar') - >>> normalize_identifiers(expression).sql() - 'SELECT bar.a AS a FROM "Foo".bar' - >>> normalize_identifiers("foo", dialect="snowflake").sql(dialect="snowflake") - 'FOO' - - Args: - expression: The expression to transform. - dialect: The dialect to use in order to decide how to normalize identifiers. - - Returns: - The transformed expression. - """ - dialect = Dialect.get_or_raise(dialect) - - if isinstance(expression, str): - expression = exp.parse_identifier(expression, dialect=dialect) - - for node in expression.walk(prune=lambda n: n.meta.get("case_sensitive")): - if not node.meta.get("case_sensitive"): - dialect.normalize_identifier(node) - - return expression diff --git a/altimate_packages/sqlglot/optimizer/optimize_joins.py b/altimate_packages/sqlglot/optimizer/optimize_joins.py deleted file mode 100644 index 15304561a..000000000 --- a/altimate_packages/sqlglot/optimizer/optimize_joins.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp -from sqlglot.helper import tsort - -JOIN_ATTRS = ("on", "side", "kind", "using", "method") - - -def optimize_joins(expression): - """ - Removes cross joins if possible and reorder joins based on predicate dependencies. - - Example: - >>> from sqlglot import parse_one - >>> optimize_joins(parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")).sql() - 'SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a' - """ - - for select in expression.find_all(exp.Select): - references = {} - cross_joins = [] - - for join in select.args.get("joins", []): - tables = other_table_names(join) - - if tables: - for table in tables: - references[table] = references.get(table, []) + [join] - else: - cross_joins.append((join.alias_or_name, join)) - - for name, join in cross_joins: - for dep in references.get(name, []): - on = dep.args["on"] - - if isinstance(on, exp.Connector): - if len(other_table_names(dep)) < 2: - continue - - operator = type(on) - for predicate in on.flatten(): - if name in exp.column_table_names(predicate): - predicate.replace(exp.true()) - predicate = exp._combine( - [join.args.get("on"), predicate], operator, copy=False - ) - join.on(predicate, append=False, copy=False) - - expression = reorder_joins(expression) - expression = normalize(expression) - return expression - - -def reorder_joins(expression): - """ - Reorder joins by topological sort order based on predicate references. - """ - for from_ in expression.find_all(exp.From): - parent = from_.parent - joins = {join.alias_or_name: join for join in parent.args.get("joins", [])} - dag = {name: other_table_names(join) for name, join in joins.items()} - parent.set( - "joins", - [joins[name] for name in tsort(dag) if name != from_.alias_or_name and name in joins], - ) - return expression - - -def normalize(expression): - """ - Remove INNER and OUTER from joins as they are optional. 
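The `normalize_identifiers` and `optimize_joins` rules deleted in these hunks can likewise still be exercised through the upstream package. A small sketch (assuming `sqlglot` is installed) that mirrors their docstring examples:

```python
import sqlglot
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.optimize_joins import optimize_joins

# Lower-case unquoted identifiers while leaving quoted, case-sensitive ones alone.
expression = sqlglot.parse_one('SELECT Bar.A AS A FROM "Foo".Bar')
print(normalize_identifiers(expression).sql())  # SELECT bar.a AS a FROM "Foo".bar

# Fold a CROSS JOIN's predicate into the join that actually references it.
joined = sqlglot.parse_one("SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a")
print(optimize_joins(joined).sql())
# SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a
```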
- """ - for join in expression.find_all(exp.Join): - if not any(join.args.get(k) for k in JOIN_ATTRS): - join.set("kind", "CROSS") - - if join.kind == "CROSS": - join.set("on", None) - else: - join.set("kind", None) - - if not join.args.get("on") and not join.args.get("using"): - join.set("on", exp.true()) - return expression - - -def other_table_names(join: exp.Join) -> t.Set[str]: - on = join.args.get("on") - return exp.column_table_names(on, join.alias_or_name) if on else set() diff --git a/altimate_packages/sqlglot/optimizer/optimizer.py b/altimate_packages/sqlglot/optimizer/optimizer.py deleted file mode 100644 index 0a0e13df2..000000000 --- a/altimate_packages/sqlglot/optimizer/optimizer.py +++ /dev/null @@ -1,94 +0,0 @@ -from __future__ import annotations - -import inspect -import typing as t - -from sqlglot import Schema, exp -from sqlglot.dialects.dialect import DialectType -from sqlglot.optimizer.annotate_types import annotate_types -from sqlglot.optimizer.canonicalize import canonicalize -from sqlglot.optimizer.eliminate_ctes import eliminate_ctes -from sqlglot.optimizer.eliminate_joins import eliminate_joins -from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries -from sqlglot.optimizer.merge_subqueries import merge_subqueries -from sqlglot.optimizer.normalize import normalize -from sqlglot.optimizer.optimize_joins import optimize_joins -from sqlglot.optimizer.pushdown_predicates import pushdown_predicates -from sqlglot.optimizer.pushdown_projections import pushdown_projections -from sqlglot.optimizer.qualify import qualify -from sqlglot.optimizer.qualify_columns import quote_identifiers -from sqlglot.optimizer.simplify import simplify -from sqlglot.optimizer.unnest_subqueries import unnest_subqueries -from sqlglot.schema import ensure_schema - -RULES = ( - qualify, - pushdown_projections, - normalize, - unnest_subqueries, - pushdown_predicates, - optimize_joins, - eliminate_subqueries, - merge_subqueries, - eliminate_joins, - eliminate_ctes, - quote_identifiers, - annotate_types, - canonicalize, - simplify, -) - - -def optimize( - expression: str | exp.Expression, - schema: t.Optional[dict | Schema] = None, - db: t.Optional[str | exp.Identifier] = None, - catalog: t.Optional[str | exp.Identifier] = None, - dialect: DialectType = None, - rules: t.Sequence[t.Callable] = RULES, - **kwargs, -) -> exp.Expression: - """ - Rewrite a sqlglot AST into an optimized form. - - Args: - expression: expression to optimize - schema: database schema. - This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of - the following forms: - 1. {table: {col: type}} - 2. {db: {table: {col: type}}} - 3. {catalog: {db: {table: {col: type}}}} - If no schema is provided then the default schema defined at `sqlgot.schema` will be used - db: specify the default database, as might be set by a `USE DATABASE db` statement - catalog: specify the default catalog, as might be set by a `USE CATALOG c` statement - dialect: The dialect to parse the sql string. - rules: sequence of optimizer rules to use. - Many of the rules require tables and columns to be qualified. - Do not remove `qualify` from the sequence of rules unless you know what you're doing! - **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. - - Returns: - The optimized expression. 
- """ - schema = ensure_schema(schema, dialect=dialect) - possible_kwargs = { - "db": db, - "catalog": catalog, - "schema": schema, - "dialect": dialect, - "isolate_tables": True, # needed for other optimizations to perform well - "quote_identifiers": False, - **kwargs, - } - - optimized = exp.maybe_parse(expression, dialect=dialect, copy=True) - for rule in rules: - # Find any additional rule parameters, beyond `expression` - rule_params = inspect.getfullargspec(rule).args - rule_kwargs = { - param: possible_kwargs[param] for param in rule_params if param in possible_kwargs - } - optimized = rule(optimized, **rule_kwargs) - - return optimized diff --git a/altimate_packages/sqlglot/optimizer/pushdown_predicates.py b/altimate_packages/sqlglot/optimizer/pushdown_predicates.py deleted file mode 100644 index 6efb63df9..000000000 --- a/altimate_packages/sqlglot/optimizer/pushdown_predicates.py +++ /dev/null @@ -1,222 +0,0 @@ -from sqlglot import exp -from sqlglot.optimizer.normalize import normalized -from sqlglot.optimizer.scope import build_scope, find_in_scope -from sqlglot.optimizer.simplify import simplify -from sqlglot import Dialect - - -def pushdown_predicates(expression, dialect=None): - """ - Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS - - Example: - >>> import sqlglot - >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1" - >>> expression = sqlglot.parse_one(sql) - >>> pushdown_predicates(expression).sql() - 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE' - - Args: - expression (sqlglot.Expression): expression to optimize - Returns: - sqlglot.Expression: optimized expression - """ - from sqlglot.dialects.presto import Presto - - root = build_scope(expression) - - dialect = Dialect.get_or_raise(dialect) - unnest_requires_cross_join = isinstance(dialect, Presto) - - if root: - scope_ref_count = root.ref_count() - - for scope in reversed(list(root.traverse())): - select = scope.expression - where = select.args.get("where") - if where: - selected_sources = scope.selected_sources - join_index = { - join.alias_or_name: i for i, join in enumerate(select.args.get("joins") or []) - } - - # a right join can only push down to itself and not the source FROM table - # presto, trino and athena don't support inner joins where the RHS is an UNNEST expression - pushdown_allowed = True - for k, (node, source) in selected_sources.items(): - parent = node.find_ancestor(exp.Join, exp.From) - if isinstance(parent, exp.Join): - if parent.side == "RIGHT": - selected_sources = {k: (node, source)} - break - if isinstance(node, exp.Unnest) and unnest_requires_cross_join: - pushdown_allowed = False - break - - if pushdown_allowed: - pushdown(where.this, selected_sources, scope_ref_count, dialect, join_index) - - # joins should only pushdown into itself, not to other joins - # so we limit the selected sources to only itself - for join in select.args.get("joins") or []: - name = join.alias_or_name - if name in scope.selected_sources: - pushdown( - join.args.get("on"), - {name: scope.selected_sources[name]}, - scope_ref_count, - dialect, - ) - - return expression - - -def pushdown(condition, sources, scope_ref_count, dialect, join_index=None): - if not condition: - return - - condition = condition.replace(simplify(condition, dialect=dialect)) - cnf_like = normalized(condition) or not normalized(condition, dnf=True) - - predicates = list( - condition.flatten() - if isinstance(condition, exp.And if cnf_like else exp.Or) - else 
[condition] - ) - - if cnf_like: - pushdown_cnf(predicates, sources, scope_ref_count, join_index=join_index) - else: - pushdown_dnf(predicates, sources, scope_ref_count) - - -def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None): - """ - If the predicates are in CNF like form, we can simply replace each block in the parent. - """ - join_index = join_index or {} - for predicate in predicates: - for node in nodes_for_predicate(predicate, sources, scope_ref_count).values(): - if isinstance(node, exp.Join): - name = node.alias_or_name - predicate_tables = exp.column_table_names(predicate, name) - - # Don't push the predicate if it references tables that appear in later joins - this_index = join_index[name] - if all(join_index.get(table, -1) < this_index for table in predicate_tables): - predicate.replace(exp.true()) - node.on(predicate, copy=False) - break - if isinstance(node, exp.Select): - predicate.replace(exp.true()) - inner_predicate = replace_aliases(node, predicate) - if find_in_scope(inner_predicate, exp.AggFunc): - node.having(inner_predicate, copy=False) - else: - node.where(inner_predicate, copy=False) - - -def pushdown_dnf(predicates, sources, scope_ref_count): - """ - If the predicates are in DNF form, we can only push down conditions that are in all blocks. - Additionally, we can't remove predicates from their original form. - """ - # find all the tables that can be pushdown too - # these are tables that are referenced in all blocks of a DNF - # (a.x AND b.x) OR (a.y AND c.y) - # only table a can be push down - pushdown_tables = set() - - for a in predicates: - a_tables = exp.column_table_names(a) - - for b in predicates: - a_tables &= exp.column_table_names(b) - - pushdown_tables.update(a_tables) - - conditions = {} - - # pushdown all predicates to their respective nodes - for table in sorted(pushdown_tables): - for predicate in predicates: - nodes = nodes_for_predicate(predicate, sources, scope_ref_count) - - if table not in nodes: - continue - - conditions[table] = ( - exp.or_(conditions[table], predicate) if table in conditions else predicate - ) - - for name, node in nodes.items(): - if name not in conditions: - continue - - predicate = conditions[name] - - if isinstance(node, exp.Join): - node.on(predicate, copy=False) - elif isinstance(node, exp.Select): - inner_predicate = replace_aliases(node, predicate) - if find_in_scope(inner_predicate, exp.AggFunc): - node.having(inner_predicate, copy=False) - else: - node.where(inner_predicate, copy=False) - - -def nodes_for_predicate(predicate, sources, scope_ref_count): - nodes = {} - tables = exp.column_table_names(predicate) - where_condition = isinstance(predicate.find_ancestor(exp.Join, exp.Where), exp.Where) - - for table in sorted(tables): - node, source = sources.get(table) or (None, None) - - # if the predicate is in a where statement we can try to push it down - # we want to find the root join or from statement - if node and where_condition: - node = node.find_ancestor(exp.Join, exp.From) - - # a node can reference a CTE which should be pushed down - if isinstance(node, exp.From) and not isinstance(source, exp.Table): - with_ = source.parent.expression.args.get("with") - if with_ and with_.recursive: - return {} - node = source.expression - - if isinstance(node, exp.Join): - if node.side and node.side != "RIGHT": - return {} - nodes[table] = node - elif isinstance(node, exp.Select) and len(tables) == 1: - # We can't push down window expressions - has_window_expression = any( - select for select in 
node.selects if select.find(exp.Window) - ) - # we can't push down predicates to select statements if they are referenced in - # multiple places. - if ( - not node.args.get("group") - and scope_ref_count[id(source)] < 2 - and not has_window_expression - ): - nodes[table] = node - return nodes - - -def replace_aliases(source, predicate): - aliases = {} - - for select in source.selects: - if isinstance(select, exp.Alias): - aliases[select.alias] = select.this - else: - aliases[select.name] = select - - def _replace_alias(column): - if isinstance(column, exp.Column) and column.name in aliases: - return aliases[column.name].copy() - return column - - return predicate.transform(_replace_alias) diff --git a/altimate_packages/sqlglot/optimizer/pushdown_projections.py b/altimate_packages/sqlglot/optimizer/pushdown_projections.py deleted file mode 100644 index ed22ce1b3..000000000 --- a/altimate_packages/sqlglot/optimizer/pushdown_projections.py +++ /dev/null @@ -1,172 +0,0 @@ -from __future__ import annotations - -import typing as t -from collections import defaultdict - -from sqlglot import alias, exp -from sqlglot.optimizer.qualify_columns import Resolver -from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import ensure_schema -from sqlglot.errors import OptimizeError -from sqlglot.helper import seq_get - -if t.TYPE_CHECKING: - from sqlglot._typing import E - from sqlglot.schema import Schema - from sqlglot.dialects.dialect import DialectType - -# Sentinel value that means an outer query selecting ALL columns -SELECT_ALL = object() - - -# Selection to use if selection list is empty -def default_selection(is_agg: bool) -> exp.Alias: - return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_") - - -def pushdown_projections( - expression: E, - schema: t.Optional[t.Dict | Schema] = None, - remove_unused_selections: bool = True, - dialect: DialectType = None, -) -> E: - """ - Rewrite sqlglot AST to remove unused columns projections. - - Example: - >>> import sqlglot - >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y" - >>> expression = sqlglot.parse_one(sql) - >>> pushdown_projections(expression).sql() - 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y' - - Args: - expression (sqlglot.Expression): expression to optimize - remove_unused_selections (bool): remove selects that are unused - Returns: - sqlglot.Expression: optimized expression - """ - # Map of Scope to all columns being selected by outer queries. - schema = ensure_schema(schema, dialect=dialect) - source_column_alias_count: t.Dict[exp.Expression | Scope, int] = {} - referenced_columns: t.DefaultDict[Scope, t.Set[str | object]] = defaultdict(set) - - # We build the scope tree (which is traversed in DFS postorder), then iterate - # over the result in reverse order. This should ensure that the set of selected - # columns for a particular scope are completely build by the time we get to it. - for scope in reversed(traverse_scope(expression)): - parent_selections = referenced_columns.get(scope, {SELECT_ALL}) - alias_count = source_column_alias_count.get(scope, 0) - - # We can't remove columns SELECT DISTINCT nor UNION DISTINCT. 
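The two pushdown rules removed in these hunks move filters and column projections closer to the scanned tables. A minimal sketch mirroring their docstring examples (assuming the upstream `sqlglot` package):

```python
import sqlglot
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.pushdown_projections import pushdown_projections

# Push an outer WHERE predicate down into the subquery it filters.
sql = "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y WHERE y.a = 1"
print(pushdown_predicates(sqlglot.parse_one(sql)).sql())
# SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x WHERE x.a = 1) AS y WHERE TRUE

# Drop inner projections that the outer query never references.
sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
print(pushdown_projections(sqlglot.parse_one(sql)).sql())
# SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y
```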
- if scope.expression.args.get("distinct"): - parent_selections = {SELECT_ALL} - - if isinstance(scope.expression, exp.SetOperation): - set_op = scope.expression - if not (set_op.kind or set_op.side): - # Do not optimize this set operation if it's using the BigQuery specific - # kind / side syntax (e.g INNER UNION ALL BY NAME) which changes the semantics of the operation - left, right = scope.union_scopes - if len(left.expression.selects) != len(right.expression.selects): - scope_sql = scope.expression.sql(dialect=dialect) - raise OptimizeError( - f"Invalid set operation due to column mismatch: {scope_sql}." - ) - - referenced_columns[left] = parent_selections - - if any(select.is_star for select in right.expression.selects): - referenced_columns[right] = parent_selections - elif not any(select.is_star for select in left.expression.selects): - if scope.expression.args.get("by_name"): - referenced_columns[right] = referenced_columns[left] - else: - referenced_columns[right] = { - right.expression.selects[i].alias_or_name - for i, select in enumerate(left.expression.selects) - if SELECT_ALL in parent_selections - or select.alias_or_name in parent_selections - } - - if isinstance(scope.expression, exp.Select): - if remove_unused_selections: - _remove_unused_selections(scope, parent_selections, schema, alias_count) - - if scope.expression.is_star: - continue - - # Group columns by source name - selects = defaultdict(set) - for col in scope.columns: - table_name = col.table - col_name = col.name - selects[table_name].add(col_name) - - # Push the selected columns down to the next scope - for name, (node, source) in scope.selected_sources.items(): - if isinstance(source, Scope): - select = seq_get(source.expression.selects, 0) - - if scope.pivots or isinstance(select, exp.QueryTransform): - columns = {SELECT_ALL} - else: - columns = selects.get(name) or set() - - referenced_columns[source].update(columns) - - column_aliases = node.alias_column_names - if column_aliases: - source_column_alias_count[source] = len(column_aliases) - - return expression - - -def _remove_unused_selections(scope, parent_selections, schema, alias_count): - order = scope.expression.args.get("order") - - if order: - # Assume columns without a qualified table are references to output columns - order_refs = {c.name for c in order.find_all(exp.Column) if not c.table} - else: - order_refs = set() - - new_selections = [] - removed = False - star = False - is_agg = False - - select_all = SELECT_ALL in parent_selections - - for selection in scope.expression.selects: - name = selection.alias_or_name - - if select_all or name in parent_selections or name in order_refs or alias_count > 0: - new_selections.append(selection) - alias_count -= 1 - else: - if selection.is_star: - star = True - removed = True - - if not is_agg and selection.find(exp.AggFunc): - is_agg = True - - if star: - resolver = Resolver(scope, schema) - names = {s.alias_or_name for s in new_selections} - - for name in sorted(parent_selections): - if name not in names: - new_selections.append( - alias(exp.column(name, table=resolver.get_table(name)), name, copy=False) - ) - - # If there are no remaining selections, just select a single constant - if not new_selections: - new_selections.append(default_selection(is_agg)) - - scope.expression.select(*new_selections, append=False, copy=False) - - if removed: - scope.clear_cache() diff --git a/altimate_packages/sqlglot/optimizer/qualify.py b/altimate_packages/sqlglot/optimizer/qualify.py deleted file mode 100644 index 
a35699d21..000000000 --- a/altimate_packages/sqlglot/optimizer/qualify.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, DialectType -from sqlglot.optimizer.isolate_table_selects import isolate_table_selects -from sqlglot.optimizer.normalize_identifiers import normalize_identifiers -from sqlglot.optimizer.qualify_columns import ( - pushdown_cte_alias_columns as pushdown_cte_alias_columns_func, - qualify_columns as qualify_columns_func, - quote_identifiers as quote_identifiers_func, - validate_qualify_columns as validate_qualify_columns_func, -) -from sqlglot.optimizer.qualify_tables import qualify_tables -from sqlglot.schema import Schema, ensure_schema - - -def qualify( - expression: exp.Expression, - dialect: DialectType = None, - db: t.Optional[str] = None, - catalog: t.Optional[str] = None, - schema: t.Optional[dict | Schema] = None, - expand_alias_refs: bool = True, - expand_stars: bool = True, - infer_schema: t.Optional[bool] = None, - isolate_tables: bool = False, - qualify_columns: bool = True, - allow_partial_qualification: bool = False, - validate_qualify_columns: bool = True, - quote_identifiers: bool = True, - identify: bool = True, - infer_csv_schemas: bool = False, -) -> exp.Expression: - """ - Rewrite sqlglot AST to have normalized and qualified tables and columns. - - This step is necessary for all further SQLGlot optimizations. - - Example: - >>> import sqlglot - >>> schema = {"tbl": {"col": "INT"}} - >>> expression = sqlglot.parse_one("SELECT col FROM tbl") - >>> qualify(expression, schema=schema).sql() - 'SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"' - - Args: - expression: Expression to qualify. - db: Default database name for tables. - catalog: Default catalog name for tables. - schema: Schema to infer column names and types. - expand_alias_refs: Whether to expand references to aliases. - expand_stars: Whether to expand star queries. This is a necessary step - for most of the optimizer's rules to work; do not set to False unless you - know what you're doing! - infer_schema: Whether to infer the schema if missing. - isolate_tables: Whether to isolate table selects. - qualify_columns: Whether to qualify columns. - allow_partial_qualification: Whether to allow partial qualification. - validate_qualify_columns: Whether to validate columns. - quote_identifiers: Whether to run the quote_identifiers step. - This step is necessary to ensure correctness for case sensitive queries. - But this flag is provided in case this step is performed at a later time. - identify: If True, quote all identifiers, else only necessary ones. - infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas. - - Returns: - The qualified expression. 
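`qualify` is the combined qualification pass (tables, identifiers, columns, quoting) whose removal begins in this hunk. A small sketch of the call shape, taken from its docstring (assuming the upstream `sqlglot` package):

```python
import sqlglot
from sqlglot.optimizer.qualify import qualify

schema = {"tbl": {"col": "INT"}}
expression = sqlglot.parse_one("SELECT col FROM tbl")
print(qualify(expression, schema=schema).sql())
# SELECT "tbl"."col" AS "col" FROM "tbl" AS "tbl"
```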
- """ - schema = ensure_schema(schema, dialect=dialect) - expression = qualify_tables( - expression, - db=db, - catalog=catalog, - schema=schema, - dialect=dialect, - infer_csv_schemas=infer_csv_schemas, - ) - expression = normalize_identifiers(expression, dialect=dialect) - - if isolate_tables: - expression = isolate_table_selects(expression, schema=schema) - - if Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN: - expression = pushdown_cte_alias_columns_func(expression) - - if qualify_columns: - expression = qualify_columns_func( - expression, - schema, - expand_alias_refs=expand_alias_refs, - expand_stars=expand_stars, - infer_schema=infer_schema, - allow_partial_qualification=allow_partial_qualification, - ) - - if quote_identifiers: - expression = quote_identifiers_func(expression, dialect=dialect, identify=identify) - - if validate_qualify_columns: - validate_qualify_columns_func(expression) - - return expression diff --git a/altimate_packages/sqlglot/optimizer/qualify_columns.py b/altimate_packages/sqlglot/optimizer/qualify_columns.py deleted file mode 100644 index dedd973d8..000000000 --- a/altimate_packages/sqlglot/optimizer/qualify_columns.py +++ /dev/null @@ -1,1024 +0,0 @@ -from __future__ import annotations - -import itertools -import typing as t - -from sqlglot import alias, exp -from sqlglot.dialects.dialect import Dialect, DialectType -from sqlglot.errors import OptimizeError -from sqlglot.helper import seq_get, SingleValuedMapping -from sqlglot.optimizer.annotate_types import TypeAnnotator -from sqlglot.optimizer.scope import Scope, build_scope, traverse_scope, walk_in_scope -from sqlglot.optimizer.simplify import simplify_parens -from sqlglot.schema import Schema, ensure_schema - -if t.TYPE_CHECKING: - from sqlglot._typing import E - - -def qualify_columns( - expression: exp.Expression, - schema: t.Dict | Schema, - expand_alias_refs: bool = True, - expand_stars: bool = True, - infer_schema: t.Optional[bool] = None, - allow_partial_qualification: bool = False, - dialect: DialectType = None, -) -> exp.Expression: - """ - Rewrite sqlglot AST to have fully qualified columns. - - Example: - >>> import sqlglot - >>> schema = {"tbl": {"col": "INT"}} - >>> expression = sqlglot.parse_one("SELECT col FROM tbl") - >>> qualify_columns(expression, schema).sql() - 'SELECT tbl.col AS col FROM tbl' - - Args: - expression: Expression to qualify. - schema: Database schema. - expand_alias_refs: Whether to expand references to aliases. - expand_stars: Whether to expand star queries. This is a necessary step - for most of the optimizer's rules to work; do not set to False unless you - know what you're doing! - infer_schema: Whether to infer the schema if missing. - allow_partial_qualification: Whether to allow partial qualification. - - Returns: - The qualified expression. 
- - Notes: - - Currently only handles a single PIVOT or UNPIVOT operator - """ - schema = ensure_schema(schema, dialect=dialect) - annotator = TypeAnnotator(schema) - infer_schema = schema.empty if infer_schema is None else infer_schema - dialect = Dialect.get_or_raise(schema.dialect) - pseudocolumns = dialect.PSEUDOCOLUMNS - bigquery = dialect == "bigquery" - - for scope in traverse_scope(expression): - scope_expression = scope.expression - is_select = isinstance(scope_expression, exp.Select) - - if is_select and scope_expression.args.get("connect"): - # In Snowflake / Oracle queries that have a CONNECT BY clause, one can use the LEVEL - # pseudocolumn, which doesn't belong to a table, so we change it into an identifier - scope_expression.transform( - lambda n: n.this if isinstance(n, exp.Column) and n.name == "LEVEL" else n, - copy=False, - ) - scope.clear_cache() - - resolver = Resolver(scope, schema, infer_schema=infer_schema) - _pop_table_column_aliases(scope.ctes) - _pop_table_column_aliases(scope.derived_tables) - using_column_tables = _expand_using(scope, resolver) - - if (schema.empty or dialect.FORCE_EARLY_ALIAS_REF_EXPANSION) and expand_alias_refs: - _expand_alias_refs( - scope, - resolver, - dialect, - expand_only_groupby=bigquery, - ) - - _convert_columns_to_dots(scope, resolver) - _qualify_columns(scope, resolver, allow_partial_qualification=allow_partial_qualification) - - if not schema.empty and expand_alias_refs: - _expand_alias_refs(scope, resolver, dialect) - - if is_select: - if expand_stars: - _expand_stars( - scope, - resolver, - using_column_tables, - pseudocolumns, - annotator, - ) - qualify_outputs(scope) - - _expand_group_by(scope, dialect) - - # DISTINCT ON and ORDER BY follow the same rules (tested in DuckDB, Postgres, ClickHouse) - # https://www.postgresql.org/docs/current/sql-select.html#SQL-DISTINCT - _expand_order_by_and_distinct_on(scope, resolver) - - if bigquery: - annotator.annotate_scope(scope) - - return expression - - -def validate_qualify_columns(expression: E) -> E: - """Raise an `OptimizeError` if any columns aren't qualified""" - all_unqualified_columns = [] - for scope in traverse_scope(expression): - if isinstance(scope.expression, exp.Select): - unqualified_columns = scope.unqualified_columns - - if scope.external_columns and not scope.is_correlated_subquery and not scope.pivots: - column = scope.external_columns[0] - for_table = f" for table: '{column.table}'" if column.table else "" - raise OptimizeError(f"Column '{column}' could not be resolved{for_table}") - - if unqualified_columns and scope.pivots and scope.pivots[0].unpivot: - # New columns produced by the UNPIVOT can't be qualified, but there may be columns - # under the UNPIVOT's IN clause that can and should be qualified. We recompute - # this list here to ensure those in the former category will be excluded. 
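For the column-level pass removed here, `qualify_columns` resolves each column to a source and `validate_qualify_columns` raises `OptimizeError` when a column stays unresolved or ambiguous. A sketch mirroring the docstring example (the failing query below is an assumed illustration; assuming the upstream `sqlglot` package):

```python
import sqlglot
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.qualify_columns import qualify_columns, validate_qualify_columns

# Qualify every column against a known schema.
schema = {"tbl": {"col": "INT"}}
print(qualify_columns(sqlglot.parse_one("SELECT col FROM tbl"), schema).sql())
# SELECT tbl.col AS col FROM tbl

# With two sources and no schema, `a` cannot be attributed to a single table.
try:
    validate_qualify_columns(sqlglot.parse_one("SELECT a FROM x, y"))
except OptimizeError as err:
    print(err)
```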
- unpivot_columns = set(_unpivot_columns(scope.pivots[0])) - unqualified_columns = [c for c in unqualified_columns if c not in unpivot_columns] - - all_unqualified_columns.extend(unqualified_columns) - - if all_unqualified_columns: - raise OptimizeError(f"Ambiguous columns: {all_unqualified_columns}") - - return expression - - -def _unpivot_columns(unpivot: exp.Pivot) -> t.Iterator[exp.Column]: - name_columns = [ - field.this - for field in unpivot.fields - if isinstance(field, exp.In) and isinstance(field.this, exp.Column) - ] - value_columns = (c for e in unpivot.expressions for c in e.find_all(exp.Column)) - - return itertools.chain(name_columns, value_columns) - - -def _pop_table_column_aliases(derived_tables: t.List[exp.CTE | exp.Subquery]) -> None: - """ - Remove table column aliases. - - For example, `col1` and `col2` will be dropped in SELECT ... FROM (SELECT ...) AS foo(col1, col2) - """ - for derived_table in derived_tables: - if isinstance(derived_table.parent, exp.With) and derived_table.parent.recursive: - continue - table_alias = derived_table.args.get("alias") - if table_alias: - table_alias.args.pop("columns", None) - - -def _expand_using(scope: Scope, resolver: Resolver) -> t.Dict[str, t.Any]: - columns = {} - - def _update_source_columns(source_name: str) -> None: - for column_name in resolver.get_source_columns(source_name): - if column_name not in columns: - columns[column_name] = source_name - - joins = list(scope.find_all(exp.Join)) - names = {join.alias_or_name for join in joins} - ordered = [key for key in scope.selected_sources if key not in names] - - if names and not ordered: - raise OptimizeError(f"Joins {names} missing source table {scope.expression}") - - # Mapping of automatically joined column names to an ordered set of source names (dict). 
- column_tables: t.Dict[str, t.Dict[str, t.Any]] = {} - - for source_name in ordered: - _update_source_columns(source_name) - - for i, join in enumerate(joins): - source_table = ordered[-1] - if source_table: - _update_source_columns(source_table) - - join_table = join.alias_or_name - ordered.append(join_table) - - using = join.args.get("using") - if not using: - continue - - join_columns = resolver.get_source_columns(join_table) - conditions = [] - using_identifier_count = len(using) - is_semi_or_anti_join = join.is_semi_or_anti_join - - for identifier in using: - identifier = identifier.name - table = columns.get(identifier) - - if not table or identifier not in join_columns: - if (columns and "*" not in columns) and join_columns: - raise OptimizeError(f"Cannot automatically join: {identifier}") - - table = table or source_table - - if i == 0 or using_identifier_count == 1: - lhs: exp.Expression = exp.column(identifier, table=table) - else: - coalesce_columns = [ - exp.column(identifier, table=t) - for t in ordered[:-1] - if identifier in resolver.get_source_columns(t) - ] - if len(coalesce_columns) > 1: - lhs = exp.func("coalesce", *coalesce_columns) - else: - lhs = exp.column(identifier, table=table) - - conditions.append(lhs.eq(exp.column(identifier, table=join_table))) - - # Set all values in the dict to None, because we only care about the key ordering - tables = column_tables.setdefault(identifier, {}) - - # Do not update the dict if this was a SEMI/ANTI join in - # order to avoid generating COALESCE columns for this join pair - if not is_semi_or_anti_join: - if table not in tables: - tables[table] = None - if join_table not in tables: - tables[join_table] = None - - join.args.pop("using") - join.set("on", exp.and_(*conditions, copy=False)) - - if column_tables: - for column in scope.columns: - if not column.table and column.name in column_tables: - tables = column_tables[column.name] - coalesce_args = [exp.column(column.name, table=table) for table in tables] - replacement: exp.Expression = exp.func("coalesce", *coalesce_args) - - if isinstance(column.parent, exp.Select): - # Ensure the USING column keeps its name if it's projected - replacement = alias(replacement, alias=column.name, copy=False) - elif isinstance(column.parent, exp.Struct): - # Ensure the USING column keeps its name if it's an anonymous STRUCT field - replacement = exp.PropertyEQ( - this=exp.to_identifier(column.name), expression=replacement - ) - - scope.replace(column, replacement) - - return column_tables - - -def _expand_alias_refs( - scope: Scope, resolver: Resolver, dialect: Dialect, expand_only_groupby: bool = False -) -> None: - """ - Expand references to aliases. 
- Example: - SELECT y.foo AS bar, bar * 2 AS baz FROM y - => SELECT y.foo AS bar, y.foo * 2 AS baz FROM y - """ - expression = scope.expression - - if not isinstance(expression, exp.Select) or dialect == "oracle": - return - - alias_to_expression: t.Dict[str, t.Tuple[exp.Expression, int]] = {} - projections = {s.alias_or_name for s in expression.selects} - - def replace_columns( - node: t.Optional[exp.Expression], resolve_table: bool = False, literal_index: bool = False - ) -> None: - is_group_by = isinstance(node, exp.Group) - is_having = isinstance(node, exp.Having) - if not node or (expand_only_groupby and not is_group_by): - return - - for column in walk_in_scope(node, prune=lambda node: node.is_star): - if not isinstance(column, exp.Column): - continue - - # BigQuery's GROUP BY allows alias expansion only for standalone names, e.g: - # SELECT FUNC(col) AS col FROM t GROUP BY col --> Can be expanded - # SELECT FUNC(col) AS col FROM t GROUP BY FUNC(col) --> Shouldn't be expanded, will result to FUNC(FUNC(col)) - # This not required for the HAVING clause as it can evaluate expressions using both the alias & the table columns - if expand_only_groupby and is_group_by and column.parent is not node: - continue - - skip_replace = False - table = resolver.get_table(column.name) if resolve_table and not column.table else None - alias_expr, i = alias_to_expression.get(column.name, (None, 1)) - - if alias_expr: - skip_replace = bool( - alias_expr.find(exp.AggFunc) - and column.find_ancestor(exp.AggFunc) - and not isinstance(column.find_ancestor(exp.Window, exp.Select), exp.Window) - ) - - # BigQuery's having clause gets confused if an alias matches a source. - # SELECT x.a, max(x.b) as x FROM x GROUP BY 1 HAVING x > 1; - # If HAVING x is expanded to max(x.b), bigquery treats x as the new projection x instead of the table - if is_having and dialect == "bigquery": - skip_replace = skip_replace or any( - node.parts[0].name in projections - for node in alias_expr.find_all(exp.Column) - ) - - if table and (not alias_expr or skip_replace): - column.set("table", table) - elif not column.table and alias_expr and not skip_replace: - if isinstance(alias_expr, exp.Literal) and (literal_index or resolve_table): - if literal_index: - column.replace(exp.Literal.number(i)) - else: - column = column.replace(exp.paren(alias_expr)) - simplified = simplify_parens(column) - if simplified is not column: - column.replace(simplified) - - for i, projection in enumerate(expression.selects): - replace_columns(projection) - if isinstance(projection, exp.Alias): - alias_to_expression[projection.alias] = (projection.this, i + 1) - - parent_scope = scope - while parent_scope.is_union: - parent_scope = parent_scope.parent - - # We shouldn't expand aliases if they match the recursive CTE's columns - if parent_scope.is_cte: - cte = parent_scope.expression.parent - if cte.find_ancestor(exp.With).recursive: - for recursive_cte_column in cte.args["alias"].columns or cte.this.selects: - alias_to_expression.pop(recursive_cte_column.output_name, None) - - replace_columns(expression.args.get("where")) - replace_columns(expression.args.get("group"), literal_index=True) - replace_columns(expression.args.get("having"), resolve_table=True) - replace_columns(expression.args.get("qualify"), resolve_table=True) - - # Snowflake allows alias expansion in the JOIN ... 
ON clause (and almost everywhere else) - # https://docs.snowflake.com/en/sql-reference/sql/select#usage-notes - if dialect == "snowflake": - for join in expression.args.get("joins") or []: - replace_columns(join) - - scope.clear_cache() - - -def _expand_group_by(scope: Scope, dialect: DialectType) -> None: - expression = scope.expression - group = expression.args.get("group") - if not group: - return - - group.set("expressions", _expand_positional_references(scope, group.expressions, dialect)) - expression.set("group", group) - - -def _expand_order_by_and_distinct_on(scope: Scope, resolver: Resolver) -> None: - for modifier_key in ("order", "distinct"): - modifier = scope.expression.args.get(modifier_key) - if isinstance(modifier, exp.Distinct): - modifier = modifier.args.get("on") - - if not isinstance(modifier, exp.Expression): - continue - - modifier_expressions = modifier.expressions - if modifier_key == "order": - modifier_expressions = [ordered.this for ordered in modifier_expressions] - - for original, expanded in zip( - modifier_expressions, - _expand_positional_references( - scope, modifier_expressions, resolver.schema.dialect, alias=True - ), - ): - for agg in original.find_all(exp.AggFunc): - for col in agg.find_all(exp.Column): - if not col.table: - col.set("table", resolver.get_table(col.name)) - - original.replace(expanded) - - if scope.expression.args.get("group"): - selects = {s.this: exp.column(s.alias_or_name) for s in scope.expression.selects} - - for expression in modifier_expressions: - expression.replace( - exp.to_identifier(_select_by_pos(scope, expression).alias) - if expression.is_int - else selects.get(expression, expression) - ) - - -def _expand_positional_references( - scope: Scope, expressions: t.Iterable[exp.Expression], dialect: DialectType, alias: bool = False -) -> t.List[exp.Expression]: - new_nodes: t.List[exp.Expression] = [] - ambiguous_projections = None - - for node in expressions: - if node.is_int: - select = _select_by_pos(scope, t.cast(exp.Literal, node)) - - if alias: - new_nodes.append(exp.column(select.args["alias"].copy())) - else: - select = select.this - - if dialect == "bigquery": - if ambiguous_projections is None: - # When a projection name is also a source name and it is referenced in the - # GROUP BY clause, BQ can't understand what the identifier corresponds to - ambiguous_projections = { - s.alias_or_name - for s in scope.expression.selects - if s.alias_or_name in scope.selected_sources - } - - ambiguous = any( - column.parts[0].name in ambiguous_projections - for column in select.find_all(exp.Column) - ) - else: - ambiguous = False - - if ( - isinstance(select, exp.CONSTANTS) - or select.find(exp.Explode, exp.Unnest) - or ambiguous - ): - new_nodes.append(node) - else: - new_nodes.append(select.copy()) - else: - new_nodes.append(node) - - return new_nodes - - -def _select_by_pos(scope: Scope, node: exp.Literal) -> exp.Alias: - try: - return scope.expression.selects[int(node.this) - 1].assert_is(exp.Alias) - except IndexError: - raise OptimizeError(f"Unknown output column: {node.name}") - - -def _convert_columns_to_dots(scope: Scope, resolver: Resolver) -> None: - """ - Converts `Column` instances that represent struct field lookup into chained `Dots`. - - Struct field lookups look like columns (e.g. "struct"."field"), but they need to be - qualified separately and represented as Dot(Dot(...(., field1), field2, ...)). 
- """ - converted = False - for column in itertools.chain(scope.columns, scope.stars): - if isinstance(column, exp.Dot): - continue - - column_table: t.Optional[str | exp.Identifier] = column.table - if ( - column_table - and column_table not in scope.sources - and ( - not scope.parent - or column_table not in scope.parent.sources - or not scope.is_correlated_subquery - ) - ): - root, *parts = column.parts - - if root.name in scope.sources: - # The struct is already qualified, but we still need to change the AST - column_table = root - root, *parts = parts - else: - column_table = resolver.get_table(root.name) - - if column_table: - converted = True - column.replace(exp.Dot.build([exp.column(root, table=column_table), *parts])) - - if converted: - # We want to re-aggregate the converted columns, otherwise they'd be skipped in - # a `for column in scope.columns` iteration, even though they shouldn't be - scope.clear_cache() - - -def _qualify_columns(scope: Scope, resolver: Resolver, allow_partial_qualification: bool) -> None: - """Disambiguate columns, ensuring each column specifies a source""" - for column in scope.columns: - column_table = column.table - column_name = column.name - - if column_table and column_table in scope.sources: - source_columns = resolver.get_source_columns(column_table) - if ( - not allow_partial_qualification - and source_columns - and column_name not in source_columns - and "*" not in source_columns - ): - raise OptimizeError(f"Unknown column: {column_name}") - - if not column_table: - if scope.pivots and not column.find_ancestor(exp.Pivot): - # If the column is under the Pivot expression, we need to qualify it - # using the name of the pivoted source instead of the pivot's alias - column.set("table", exp.to_identifier(scope.pivots[0].alias)) - continue - - # column_table can be a '' because bigquery unnest has no table alias - column_table = resolver.get_table(column_name) - if column_table: - column.set("table", column_table) - - for pivot in scope.pivots: - for column in pivot.find_all(exp.Column): - if not column.table and column.name in resolver.all_columns: - column_table = resolver.get_table(column.name) - if column_table: - column.set("table", column_table) - - -def _expand_struct_stars( - expression: exp.Dot, -) -> t.List[exp.Alias]: - """[BigQuery] Expand/Flatten foo.bar.* where bar is a struct column""" - - dot_column = t.cast(exp.Column, expression.find(exp.Column)) - if not dot_column.is_type(exp.DataType.Type.STRUCT): - return [] - - # All nested struct values are ColumnDefs, so normalize the first exp.Column in one - dot_column = dot_column.copy() - starting_struct = exp.ColumnDef(this=dot_column.this, kind=dot_column.type) - - # First part is the table name and last part is the star so they can be dropped - dot_parts = expression.parts[1:-1] - - # If we're expanding a nested struct eg. 
t.c.f1.f2.* find the last struct (f2 in this case) - for part in dot_parts[1:]: - for field in t.cast(exp.DataType, starting_struct.kind).expressions: - # Unable to expand star unless all fields are named - if not isinstance(field.this, exp.Identifier): - return [] - - if field.name == part.name and field.kind.is_type(exp.DataType.Type.STRUCT): - starting_struct = field - break - else: - # There is no matching field in the struct - return [] - - taken_names = set() - new_selections = [] - - for field in t.cast(exp.DataType, starting_struct.kind).expressions: - name = field.name - - # Ambiguous or anonymous fields can't be expanded - if name in taken_names or not isinstance(field.this, exp.Identifier): - return [] - - taken_names.add(name) - - this = field.this.copy() - root, *parts = [part.copy() for part in itertools.chain(dot_parts, [this])] - new_column = exp.column( - t.cast(exp.Identifier, root), - table=dot_column.args.get("table"), - fields=t.cast(t.List[exp.Identifier], parts), - ) - new_selections.append(alias(new_column, this, copy=False)) - - return new_selections - - -def _expand_stars( - scope: Scope, - resolver: Resolver, - using_column_tables: t.Dict[str, t.Any], - pseudocolumns: t.Set[str], - annotator: TypeAnnotator, -) -> None: - """Expand stars to lists of column selections""" - - new_selections: t.List[exp.Expression] = [] - except_columns: t.Dict[int, t.Set[str]] = {} - replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] = {} - rename_columns: t.Dict[int, t.Dict[str, str]] = {} - - coalesced_columns = set() - dialect = resolver.schema.dialect - - pivot_output_columns = None - pivot_exclude_columns: t.Set[str] = set() - - pivot = t.cast(t.Optional[exp.Pivot], seq_get(scope.pivots, 0)) - if isinstance(pivot, exp.Pivot) and not pivot.alias_column_names: - if pivot.unpivot: - pivot_output_columns = [c.output_name for c in _unpivot_columns(pivot)] - - for field in pivot.fields: - if isinstance(field, exp.In): - pivot_exclude_columns.update( - c.output_name for e in field.expressions for c in e.find_all(exp.Column) - ) - - else: - pivot_exclude_columns = set(c.output_name for c in pivot.find_all(exp.Column)) - - pivot_output_columns = [c.output_name for c in pivot.args.get("columns", [])] - if not pivot_output_columns: - pivot_output_columns = [c.alias_or_name for c in pivot.expressions] - - is_bigquery = dialect == "bigquery" - if is_bigquery and any(isinstance(col, exp.Dot) for col in scope.stars): - # Found struct expansion, annotate scope ahead of time - annotator.annotate_scope(scope) - - for expression in scope.expression.selects: - tables = [] - if isinstance(expression, exp.Star): - tables.extend(scope.selected_sources) - _add_except_columns(expression, tables, except_columns) - _add_replace_columns(expression, tables, replace_columns) - _add_rename_columns(expression, tables, rename_columns) - elif expression.is_star: - if not isinstance(expression, exp.Dot): - tables.append(expression.table) - _add_except_columns(expression.this, tables, except_columns) - _add_replace_columns(expression.this, tables, replace_columns) - _add_rename_columns(expression.this, tables, rename_columns) - elif is_bigquery: - struct_fields = _expand_struct_stars(expression) - if struct_fields: - new_selections.extend(struct_fields) - continue - - if not tables: - new_selections.append(expression) - continue - - for table in tables: - if table not in scope.sources: - raise OptimizeError(f"Unknown table: {table}") - - columns = resolver.get_source_columns(table, only_visible=True) - columns 
= columns or scope.outer_columns - - if pseudocolumns: - columns = [name for name in columns if name.upper() not in pseudocolumns] - - if not columns or "*" in columns: - return - - table_id = id(table) - columns_to_exclude = except_columns.get(table_id) or set() - renamed_columns = rename_columns.get(table_id, {}) - replaced_columns = replace_columns.get(table_id, {}) - - if pivot: - if pivot_output_columns and pivot_exclude_columns: - pivot_columns = [c for c in columns if c not in pivot_exclude_columns] - pivot_columns.extend(pivot_output_columns) - else: - pivot_columns = pivot.alias_column_names - - if pivot_columns: - new_selections.extend( - alias(exp.column(name, table=pivot.alias), name, copy=False) - for name in pivot_columns - if name not in columns_to_exclude - ) - continue - - for name in columns: - if name in columns_to_exclude or name in coalesced_columns: - continue - if name in using_column_tables and table in using_column_tables[name]: - coalesced_columns.add(name) - tables = using_column_tables[name] - coalesce_args = [exp.column(name, table=table) for table in tables] - - new_selections.append( - alias(exp.func("coalesce", *coalesce_args), alias=name, copy=False) - ) - else: - alias_ = renamed_columns.get(name, name) - selection_expr = replaced_columns.get(name) or exp.column(name, table=table) - new_selections.append( - alias(selection_expr, alias_, copy=False) - if alias_ != name - else selection_expr - ) - - # Ensures we don't overwrite the initial selections with an empty list - if new_selections and isinstance(scope.expression, exp.Select): - scope.expression.set("expressions", new_selections) - - -def _add_except_columns( - expression: exp.Expression, tables, except_columns: t.Dict[int, t.Set[str]] -) -> None: - except_ = expression.args.get("except") - - if not except_: - return - - columns = {e.name for e in except_} - - for table in tables: - except_columns[id(table)] = columns - - -def _add_rename_columns( - expression: exp.Expression, tables, rename_columns: t.Dict[int, t.Dict[str, str]] -) -> None: - rename = expression.args.get("rename") - - if not rename: - return - - columns = {e.this.name: e.alias for e in rename} - - for table in tables: - rename_columns[id(table)] = columns - - -def _add_replace_columns( - expression: exp.Expression, tables, replace_columns: t.Dict[int, t.Dict[str, exp.Alias]] -) -> None: - replace = expression.args.get("replace") - - if not replace: - return - - columns = {e.alias: e for e in replace} - - for table in tables: - replace_columns[id(table)] = columns - - -def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None: - """Ensure all output columns are aliased""" - if isinstance(scope_or_expression, exp.Expression): - scope = build_scope(scope_or_expression) - if not isinstance(scope, Scope): - return - else: - scope = scope_or_expression - - new_selections = [] - for i, (selection, aliased_column) in enumerate( - itertools.zip_longest(scope.expression.selects, scope.outer_columns) - ): - if selection is None or isinstance(selection, exp.QueryTransform): - break - - if isinstance(selection, exp.Subquery): - if not selection.output_name: - selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) - elif not isinstance(selection, exp.Alias) and not selection.is_star: - selection = alias( - selection, - alias=selection.output_name or f"_col_{i}", - copy=False, - ) - if aliased_column: - selection.set("alias", exp.to_identifier(aliased_column)) - - new_selections.append(selection) - - if 
new_selections and isinstance(scope.expression, exp.Select): - scope.expression.set("expressions", new_selections) - - -def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool = True) -> E: - """Makes sure all identifiers that need to be quoted are quoted.""" - return expression.transform( - Dialect.get_or_raise(dialect).quote_identifier, identify=identify, copy=False - ) # type: ignore - - -def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression: - """ - Pushes down the CTE alias columns into the projection, - - This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y") - >>> pushdown_cte_alias_columns(expression).sql() - 'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y' - - Args: - expression: Expression to pushdown. - - Returns: - The expression with the CTE aliases pushed down into the projection. - """ - for cte in expression.find_all(exp.CTE): - if cte.alias_column_names: - new_expressions = [] - for _alias, projection in zip(cte.alias_column_names, cte.this.expressions): - if isinstance(projection, exp.Alias): - projection.set("alias", _alias) - else: - projection = alias(projection, alias=_alias) - new_expressions.append(projection) - cte.this.set("expressions", new_expressions) - - return expression - - -class Resolver: - """ - Helper for resolving columns. - - This is a class so we can lazily load some things and easily share them across functions. - """ - - def __init__(self, scope: Scope, schema: Schema, infer_schema: bool = True): - self.scope = scope - self.schema = schema - self._source_columns: t.Optional[t.Dict[str, t.Sequence[str]]] = None - self._unambiguous_columns: t.Optional[t.Mapping[str, str]] = None - self._all_columns: t.Optional[t.Set[str]] = None - self._infer_schema = infer_schema - self._get_source_columns_cache: t.Dict[t.Tuple[str, bool], t.Sequence[str]] = {} - - def get_table(self, column_name: str) -> t.Optional[exp.Identifier]: - """ - Get the table for a column name. - - Args: - column_name: The column name to find the table for. - Returns: - The table name if it can be found/inferred. 
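Two smaller helpers deleted in this hunk, `quote_identifiers` and `pushdown_cte_alias_columns`, round out the qualification step. A sketch based on the docstring example above (assuming the upstream `sqlglot` package):

```python
import sqlglot
from sqlglot.optimizer.qualify_columns import pushdown_cte_alias_columns, quote_identifiers

# Push CTE alias columns into the CTE's projections (handy for Snowflake-style HAVING).
sql = "WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y"
print(pushdown_cte_alias_columns(sqlglot.parse_one(sql)).sql())
# WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y

# Quote every identifier that needs quoting for a given dialect.
print(quote_identifiers(sqlglot.parse_one("SELECT col FROM tbl"), dialect="snowflake").sql())
```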
- """ - if self._unambiguous_columns is None: - self._unambiguous_columns = self._get_unambiguous_columns( - self._get_all_source_columns() - ) - - table_name = self._unambiguous_columns.get(column_name) - - if not table_name and self._infer_schema: - sources_without_schema = tuple( - source - for source, columns in self._get_all_source_columns().items() - if not columns or "*" in columns - ) - if len(sources_without_schema) == 1: - table_name = sources_without_schema[0] - - if table_name not in self.scope.selected_sources: - return exp.to_identifier(table_name) - - node, _ = self.scope.selected_sources.get(table_name) - - if isinstance(node, exp.Query): - while node and node.alias != table_name: - node = node.parent - - node_alias = node.args.get("alias") - if node_alias: - return exp.to_identifier(node_alias.this) - - return exp.to_identifier(table_name) - - @property - def all_columns(self) -> t.Set[str]: - """All available columns of all sources in this scope""" - if self._all_columns is None: - self._all_columns = { - column for columns in self._get_all_source_columns().values() for column in columns - } - return self._all_columns - - def get_source_columns(self, name: str, only_visible: bool = False) -> t.Sequence[str]: - """Resolve the source columns for a given source `name`.""" - cache_key = (name, only_visible) - if cache_key not in self._get_source_columns_cache: - if name not in self.scope.sources: - raise OptimizeError(f"Unknown table: {name}") - - source = self.scope.sources[name] - - if isinstance(source, exp.Table): - columns = self.schema.column_names(source, only_visible) - elif isinstance(source, Scope) and isinstance( - source.expression, (exp.Values, exp.Unnest) - ): - columns = source.expression.named_selects - - # in bigquery, unnest structs are automatically scoped as tables, so you can - # directly select a struct field in a query. - # this handles the case where the unnest is statically defined. - if self.schema.dialect == "bigquery": - if source.expression.is_type(exp.DataType.Type.STRUCT): - for k in source.expression.type.expressions: # type: ignore - columns.append(k.name) - elif isinstance(source, Scope) and isinstance(source.expression, exp.SetOperation): - set_op = source.expression - - # BigQuery specific set operations modifiers, e.g INNER UNION ALL BY NAME - on_column_list = set_op.args.get("on") - - if on_column_list: - # The resulting columns are the columns in the ON clause: - # {INNER | LEFT | FULL} UNION ALL BY NAME ON (col1, col2, ...) 
- columns = [col.name for col in on_column_list] - elif set_op.side or set_op.kind: - side = set_op.side - kind = set_op.kind - - left = set_op.left.named_selects - right = set_op.right.named_selects - - # We use dict.fromkeys to deduplicate keys and maintain insertion order - if side == "LEFT": - columns = left - elif side == "FULL": - columns = list(dict.fromkeys(left + right)) - elif kind == "INNER": - columns = list(dict.fromkeys(left).keys() & dict.fromkeys(right).keys()) - else: - columns = set_op.named_selects - else: - select = seq_get(source.expression.selects, 0) - - if isinstance(select, exp.QueryTransform): - # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-qry-select-transform.html - schema = select.args.get("schema") - columns = [c.name for c in schema.expressions] if schema else ["key", "value"] - else: - columns = source.expression.named_selects - - node, _ = self.scope.selected_sources.get(name) or (None, None) - if isinstance(node, Scope): - column_aliases = node.expression.alias_column_names - elif isinstance(node, exp.Expression): - column_aliases = node.alias_column_names - else: - column_aliases = [] - - if column_aliases: - # If the source's columns are aliased, their aliases shadow the corresponding column names. - # This can be expensive if there are lots of columns, so only do this if column_aliases exist. - columns = [ - alias or name - for (name, alias) in itertools.zip_longest(columns, column_aliases) - ] - - self._get_source_columns_cache[cache_key] = columns - - return self._get_source_columns_cache[cache_key] - - def _get_all_source_columns(self) -> t.Dict[str, t.Sequence[str]]: - if self._source_columns is None: - self._source_columns = { - source_name: self.get_source_columns(source_name) - for source_name, source in itertools.chain( - self.scope.selected_sources.items(), self.scope.lateral_sources.items() - ) - } - return self._source_columns - - def _get_unambiguous_columns( - self, source_columns: t.Dict[str, t.Sequence[str]] - ) -> t.Mapping[str, str]: - """ - Find all the unambiguous columns in sources. - - Args: - source_columns: Mapping of names to source columns. - - Returns: - Mapping of column name to source name. - """ - if not source_columns: - return {} - - source_columns_pairs = list(source_columns.items()) - - first_table, first_columns = source_columns_pairs[0] - - if len(source_columns_pairs) == 1: - # Performance optimization - avoid copying first_columns if there is only one table. 
- return SingleValuedMapping(first_columns, first_table) - - unambiguous_columns = {col: first_table for col in first_columns} - all_columns = set(unambiguous_columns) - - for table, columns in source_columns_pairs[1:]: - unique = set(columns) - ambiguous = all_columns.intersection(unique) - all_columns.update(columns) - - for column in ambiguous: - unambiguous_columns.pop(column, None) - for column in unique.difference(ambiguous): - unambiguous_columns[column] = table - - return unambiguous_columns diff --git a/altimate_packages/sqlglot/optimizer/qualify_tables.py b/altimate_packages/sqlglot/optimizer/qualify_tables.py deleted file mode 100644 index 47cedd708..000000000 --- a/altimate_packages/sqlglot/optimizer/qualify_tables.py +++ /dev/null @@ -1,155 +0,0 @@ -from __future__ import annotations - -import itertools -import typing as t - -from sqlglot import alias, exp -from sqlglot.dialects.dialect import DialectType -from sqlglot.helper import csv_reader, name_sequence -from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import Schema -from sqlglot.dialects.dialect import Dialect - -if t.TYPE_CHECKING: - from sqlglot._typing import E - - -def qualify_tables( - expression: E, - db: t.Optional[str | exp.Identifier] = None, - catalog: t.Optional[str | exp.Identifier] = None, - schema: t.Optional[Schema] = None, - infer_csv_schemas: bool = False, - dialect: DialectType = None, -) -> E: - """ - Rewrite sqlglot AST to have fully qualified tables. Join constructs such as - (t1 JOIN t2) AS t will be expanded into (SELECT * FROM t1 AS t1, t2 AS t2) AS t. - - Examples: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl") - >>> qualify_tables(expression, db="db").sql() - 'SELECT 1 FROM db.tbl AS tbl' - >>> - >>> expression = sqlglot.parse_one("SELECT 1 FROM (t1 JOIN t2) AS t") - >>> qualify_tables(expression).sql() - 'SELECT 1 FROM (SELECT * FROM t1 AS t1, t2 AS t2) AS t' - - Args: - expression: Expression to qualify - db: Database name - catalog: Catalog name - schema: A schema to populate - infer_csv_schemas: Whether to scan READ_CSV calls in order to infer the CSVs' schemas. - dialect: The dialect to parse catalog and schema into. - - Returns: - The qualified expression. 
- """ - next_alias_name = name_sequence("_q_") - db = exp.parse_identifier(db, dialect=dialect) if db else None - catalog = exp.parse_identifier(catalog, dialect=dialect) if catalog else None - dialect = Dialect.get_or_raise(dialect) - - def _qualify(table: exp.Table) -> None: - if isinstance(table.this, exp.Identifier): - if db and not table.args.get("db"): - table.set("db", db.copy()) - if catalog and not table.args.get("catalog") and table.args.get("db"): - table.set("catalog", catalog.copy()) - - if (db or catalog) and not isinstance(expression, exp.Query): - for node in expression.walk(prune=lambda n: isinstance(n, exp.Query)): - if isinstance(node, exp.Table): - _qualify(node) - - for scope in traverse_scope(expression): - for derived_table in itertools.chain(scope.ctes, scope.derived_tables): - if isinstance(derived_table, exp.Subquery): - unnested = derived_table.unnest() - if isinstance(unnested, exp.Table): - joins = unnested.args.pop("joins", None) - derived_table.this.replace(exp.select("*").from_(unnested.copy(), copy=False)) - derived_table.this.set("joins", joins) - - if not derived_table.args.get("alias"): - alias_ = next_alias_name() - derived_table.set("alias", exp.TableAlias(this=exp.to_identifier(alias_))) - scope.rename_source(None, alias_) - - pivots = derived_table.args.get("pivots") - if pivots and not pivots[0].alias: - pivots[0].set("alias", exp.TableAlias(this=exp.to_identifier(next_alias_name()))) - - table_aliases = {} - - for name, source in scope.sources.items(): - if isinstance(source, exp.Table): - pivots = source.args.get("pivots") - if not source.alias: - # Don't add the pivot's alias to the pivoted table, use the table's name instead - if pivots and pivots[0].alias == name: - name = source.name - - # Mutates the source by attaching an alias to it - alias(source, name or source.name or next_alias_name(), copy=False, table=True) - - table_aliases[".".join(p.name for p in source.parts)] = exp.to_identifier( - source.alias - ) - - if pivots: - pivot = pivots[0] - if not pivot.alias: - pivot_alias = source.alias if pivot.unpivot else next_alias_name() - pivot.set("alias", exp.TableAlias(this=exp.to_identifier(pivot_alias))) - - # This case corresponds to a pivoted CTE, we don't want to qualify that - if isinstance(scope.sources.get(source.alias_or_name), Scope): - continue - - _qualify(source) - - if infer_csv_schemas and schema and isinstance(source.this, exp.ReadCSV): - with csv_reader(source.this) as reader: - header = next(reader) - columns = next(reader) - schema.add_table( - source, - {k: type(v).__name__ for k, v in zip(header, columns)}, - match_depth=False, - ) - elif isinstance(source, Scope) and source.is_udtf: - udtf = source.expression - table_alias = udtf.args.get("alias") or exp.TableAlias( - this=exp.to_identifier(next_alias_name()) - ) - udtf.set("alias", table_alias) - - if not table_alias.name: - table_alias.set("this", exp.to_identifier(next_alias_name())) - if isinstance(udtf, exp.Values) and not table_alias.columns: - column_aliases = dialect.generate_values_aliases(udtf) - table_alias.set("columns", column_aliases) - else: - for node in scope.walk(): - if ( - isinstance(node, exp.Table) - and not node.alias - and isinstance(node.parent, (exp.From, exp.Join)) - ): - # Mutates the table by attaching an alias to it - alias(node, node.name, copy=False, table=True) - - for column in scope.columns: - if column.db: - table_alias = table_aliases.get(".".join(p.name for p in column.parts[0:-1])) - - if table_alias: - for p in 
exp.COLUMN_PARTS[1:]: - column.set(p, None) - - column.set("table", table_alias.copy()) - - return expression diff --git a/altimate_packages/sqlglot/optimizer/scope.py b/altimate_packages/sqlglot/optimizer/scope.py deleted file mode 100644 index df6072c4b..000000000 --- a/altimate_packages/sqlglot/optimizer/scope.py +++ /dev/null @@ -1,904 +0,0 @@ -from __future__ import annotations - -import itertools -import logging -import typing as t -from collections import defaultdict -from enum import Enum, auto - -from sqlglot import exp -from sqlglot.errors import OptimizeError -from sqlglot.helper import ensure_collection, find_new_name, seq_get - -logger = logging.getLogger("sqlglot") - -TRAVERSABLES = (exp.Query, exp.DDL, exp.DML) - - -class ScopeType(Enum): - ROOT = auto() - SUBQUERY = auto() - DERIVED_TABLE = auto() - CTE = auto() - UNION = auto() - UDTF = auto() - - -class Scope: - """ - Selection scope. - - Attributes: - expression (exp.Select|exp.SetOperation): Root expression of this scope - sources (dict[str, exp.Table|Scope]): Mapping of source name to either - a Table expression or another Scope instance. For example: - SELECT * FROM x {"x": Table(this="x")} - SELECT * FROM x AS y {"y": Table(this="x")} - SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} - lateral_sources (dict[str, exp.Table|Scope]): Sources from laterals - For example: - SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c; - The LATERAL VIEW EXPLODE gets x as a source. - cte_sources (dict[str, Scope]): Sources from CTES - outer_columns (list[str]): If this is a derived table or CTE, and the outer query - defines a column list for the alias of this scope, this is that list of columns. - For example: - SELECT * FROM (SELECT ...) AS y(col1, col2) - The inner query would have `["col1", "col2"]` for its `outer_columns` - parent (Scope): Parent scope - scope_type (ScopeType): Type of this scope, relative to it's parent - subquery_scopes (list[Scope]): List of all child scopes for subqueries - cte_scopes (list[Scope]): List of all child scopes for CTEs - derived_table_scopes (list[Scope]): List of all child scopes for derived_tables - udtf_scopes (list[Scope]): List of all child scopes for user defined tabular functions - table_scopes (list[Scope]): derived_table_scopes + udtf_scopes, in the order that they're defined - union_scopes (list[Scope, Scope]): If this Scope is for a Union expression, this will be - a list of the left and right child scopes. 
- """ - - def __init__( - self, - expression, - sources=None, - outer_columns=None, - parent=None, - scope_type=ScopeType.ROOT, - lateral_sources=None, - cte_sources=None, - can_be_correlated=None, - ): - self.expression = expression - self.sources = sources or {} - self.lateral_sources = lateral_sources or {} - self.cte_sources = cte_sources or {} - self.sources.update(self.lateral_sources) - self.sources.update(self.cte_sources) - self.outer_columns = outer_columns or [] - self.parent = parent - self.scope_type = scope_type - self.subquery_scopes = [] - self.derived_table_scopes = [] - self.table_scopes = [] - self.cte_scopes = [] - self.union_scopes = [] - self.udtf_scopes = [] - self.can_be_correlated = can_be_correlated - self.clear_cache() - - def clear_cache(self): - self._collected = False - self._raw_columns = None - self._stars = None - self._derived_tables = None - self._udtfs = None - self._tables = None - self._ctes = None - self._subqueries = None - self._selected_sources = None - self._columns = None - self._external_columns = None - self._join_hints = None - self._pivots = None - self._references = None - self._semi_anti_join_tables = None - - def branch( - self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs - ): - """Branch from the current scope to a new, inner scope""" - return Scope( - expression=expression.unnest(), - sources=sources.copy() if sources else None, - parent=self, - scope_type=scope_type, - cte_sources={**self.cte_sources, **(cte_sources or {})}, - lateral_sources=lateral_sources.copy() if lateral_sources else None, - can_be_correlated=self.can_be_correlated - or scope_type in (ScopeType.SUBQUERY, ScopeType.UDTF), - **kwargs, - ) - - def _collect(self): - self._tables = [] - self._ctes = [] - self._subqueries = [] - self._derived_tables = [] - self._udtfs = [] - self._raw_columns = [] - self._stars = [] - self._join_hints = [] - self._semi_anti_join_tables = set() - - for node in self.walk(bfs=False): - if node is self.expression: - continue - - if isinstance(node, exp.Dot) and node.is_star: - self._stars.append(node) - elif isinstance(node, exp.Column): - if isinstance(node.this, exp.Star): - self._stars.append(node) - else: - self._raw_columns.append(node) - elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint): - parent = node.parent - if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join: - self._semi_anti_join_tables.add(node.alias_or_name) - - self._tables.append(node) - elif isinstance(node, exp.JoinHint): - self._join_hints.append(node) - elif isinstance(node, exp.UDTF): - self._udtfs.append(node) - elif isinstance(node, exp.CTE): - self._ctes.append(node) - elif _is_derived_table(node) and _is_from_or_join(node): - self._derived_tables.append(node) - elif isinstance(node, exp.UNWRAPPED_QUERIES): - self._subqueries.append(node) - - self._collected = True - - def _ensure_collected(self): - if not self._collected: - self._collect() - - def walk(self, bfs=True, prune=None): - return walk_in_scope(self.expression, bfs=bfs, prune=None) - - def find(self, *expression_types, bfs=True): - return find_in_scope(self.expression, expression_types, bfs=bfs) - - def find_all(self, *expression_types, bfs=True): - return find_all_in_scope(self.expression, expression_types, bfs=bfs) - - def replace(self, old, new): - """ - Replace `old` with `new`. - - This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. 
- - Args: - old (exp.Expression): old node - new (exp.Expression): new node - """ - old.replace(new) - self.clear_cache() - - @property - def tables(self): - """ - List of tables in this scope. - - Returns: - list[exp.Table]: tables - """ - self._ensure_collected() - return self._tables - - @property - def ctes(self): - """ - List of CTEs in this scope. - - Returns: - list[exp.CTE]: ctes - """ - self._ensure_collected() - return self._ctes - - @property - def derived_tables(self): - """ - List of derived tables in this scope. - - For example: - SELECT * FROM (SELECT ...) <- that's a derived table - - Returns: - list[exp.Subquery]: derived tables - """ - self._ensure_collected() - return self._derived_tables - - @property - def udtfs(self): - """ - List of "User Defined Tabular Functions" in this scope. - - Returns: - list[exp.UDTF]: UDTFs - """ - self._ensure_collected() - return self._udtfs - - @property - def subqueries(self): - """ - List of subqueries in this scope. - - For example: - SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery - - Returns: - list[exp.Select | exp.SetOperation]: subqueries - """ - self._ensure_collected() - return self._subqueries - - @property - def stars(self) -> t.List[exp.Column | exp.Dot]: - """ - List of star expressions (columns or dots) in this scope. - """ - self._ensure_collected() - return self._stars - - @property - def columns(self): - """ - List of columns in this scope. - - Returns: - list[exp.Column]: Column instances in this scope, plus any - Columns that reference this scope from correlated subqueries. - """ - if self._columns is None: - self._ensure_collected() - columns = self._raw_columns - - external_columns = [ - column - for scope in itertools.chain( - self.subquery_scopes, - self.udtf_scopes, - (dts for dts in self.derived_table_scopes if dts.can_be_correlated), - ) - for column in scope.external_columns - ] - - named_selects = set(self.expression.named_selects) - - self._columns = [] - for column in columns + external_columns: - ancestor = column.find_ancestor( - exp.Select, - exp.Qualify, - exp.Order, - exp.Having, - exp.Hint, - exp.Table, - exp.Star, - exp.Distinct, - ) - if ( - not ancestor - or column.table - or isinstance(ancestor, exp.Select) - or (isinstance(ancestor, exp.Table) and not isinstance(ancestor.this, exp.Func)) - or ( - isinstance(ancestor, (exp.Order, exp.Distinct)) - and ( - isinstance(ancestor.parent, (exp.Window, exp.WithinGroup)) - or column.name not in named_selects - ) - ) - or (isinstance(ancestor, exp.Star) and not column.arg_key == "except") - ): - self._columns.append(column) - - return self._columns - - @property - def selected_sources(self): - """ - Mapping of nodes and sources that are actually selected from in this scope. - - That is, all tables in a schema are selectable at any point. But a - table only becomes a selected source if it's included in a FROM or JOIN clause. 
- - Returns: - dict[str, (exp.Table|exp.Select, exp.Table|Scope)]: selected sources and nodes - """ - if self._selected_sources is None: - result = {} - - for name, node in self.references: - if name in self._semi_anti_join_tables: - # The RHS table of SEMI/ANTI joins shouldn't be collected as a - # selected source - continue - - if name in result: - raise OptimizeError(f"Alias already used: {name}") - if name in self.sources: - result[name] = (node, self.sources[name]) - - self._selected_sources = result - return self._selected_sources - - @property - def references(self) -> t.List[t.Tuple[str, exp.Expression]]: - if self._references is None: - self._references = [] - - for table in self.tables: - self._references.append((table.alias_or_name, table)) - for expression in itertools.chain(self.derived_tables, self.udtfs): - self._references.append( - ( - expression.alias, - expression if expression.args.get("pivots") else expression.unnest(), - ) - ) - - return self._references - - @property - def external_columns(self): - """ - Columns that appear to reference sources in outer scopes. - - Returns: - list[exp.Column]: Column instances that don't reference - sources in the current scope. - """ - if self._external_columns is None: - if isinstance(self.expression, exp.SetOperation): - left, right = self.union_scopes - self._external_columns = left.external_columns + right.external_columns - else: - self._external_columns = [ - c - for c in self.columns - if c.table not in self.selected_sources - and c.table not in self.semi_or_anti_join_tables - ] - - return self._external_columns - - @property - def unqualified_columns(self): - """ - Unqualified columns in the current scope. - - Returns: - list[exp.Column]: Unqualified columns - """ - return [c for c in self.columns if not c.table] - - @property - def join_hints(self): - """ - Hints that exist in the scope that reference tables - - Returns: - list[exp.JoinHint]: Join hints that are referenced within the scope - """ - if self._join_hints is None: - return [] - return self._join_hints - - @property - def pivots(self): - if not self._pivots: - self._pivots = [ - pivot for _, node in self.references for pivot in node.args.get("pivots") or [] - ] - - return self._pivots - - @property - def semi_or_anti_join_tables(self): - return self._semi_anti_join_tables or set() - - def source_columns(self, source_name): - """ - Get all columns in the current scope for a particular source. 
- - Args: - source_name (str): Name of the source - Returns: - list[exp.Column]: Column instances that reference `source_name` - """ - return [column for column in self.columns if column.table == source_name] - - @property - def is_subquery(self): - """Determine if this scope is a subquery""" - return self.scope_type == ScopeType.SUBQUERY - - @property - def is_derived_table(self): - """Determine if this scope is a derived table""" - return self.scope_type == ScopeType.DERIVED_TABLE - - @property - def is_union(self): - """Determine if this scope is a union""" - return self.scope_type == ScopeType.UNION - - @property - def is_cte(self): - """Determine if this scope is a common table expression""" - return self.scope_type == ScopeType.CTE - - @property - def is_root(self): - """Determine if this is the root scope""" - return self.scope_type == ScopeType.ROOT - - @property - def is_udtf(self): - """Determine if this scope is a UDTF (User Defined Table Function)""" - return self.scope_type == ScopeType.UDTF - - @property - def is_correlated_subquery(self): - """Determine if this scope is a correlated subquery""" - return bool(self.can_be_correlated and self.external_columns) - - def rename_source(self, old_name, new_name): - """Rename a source in this scope""" - columns = self.sources.pop(old_name or "", []) - self.sources[new_name] = columns - - def add_source(self, name, source): - """Add a source to this scope""" - self.sources[name] = source - self.clear_cache() - - def remove_source(self, name): - """Remove a source from this scope""" - self.sources.pop(name, None) - self.clear_cache() - - def __repr__(self): - return f"Scope<{self.expression.sql()}>" - - def traverse(self): - """ - Traverse the scope tree from this node. - - Yields: - Scope: scope instances in depth-first-search post-order - """ - stack = [self] - result = [] - while stack: - scope = stack.pop() - result.append(scope) - stack.extend( - itertools.chain( - scope.cte_scopes, - scope.union_scopes, - scope.table_scopes, - scope.subquery_scopes, - ) - ) - - yield from reversed(result) - - def ref_count(self): - """ - Count the number of times each scope in this tree is referenced. - - Returns: - dict[int, int]: Mapping of Scope instance ID to reference count - """ - scope_ref_count = defaultdict(lambda: 0) - - for scope in self.traverse(): - for _, source in scope.selected_sources.values(): - scope_ref_count[id(source)] += 1 - - return scope_ref_count - - -def traverse_scope(expression: exp.Expression) -> t.List[Scope]: - """ - Traverse an expression by its "scopes". - - "Scope" represents the current context of a Select statement. - - This is helpful for optimizing queries, where we need more information than - the expression tree itself. For example, we might care about the source - names within a subquery. Returns a list because a generator could result in - incomplete properties which is confusing. 
- - Examples: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") - >>> scopes = traverse_scope(expression) - >>> scopes[0].expression.sql(), list(scopes[0].sources) - ('SELECT a FROM x', ['x']) - >>> scopes[1].expression.sql(), list(scopes[1].sources) - ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) - - Args: - expression: Expression to traverse - - Returns: - A list of the created scope instances - """ - if isinstance(expression, TRAVERSABLES): - return list(_traverse_scope(Scope(expression))) - return [] - - -def build_scope(expression: exp.Expression) -> t.Optional[Scope]: - """ - Build a scope tree. - - Args: - expression: Expression to build the scope tree for. - - Returns: - The root scope - """ - return seq_get(traverse_scope(expression), -1) - - -def _traverse_scope(scope): - expression = scope.expression - - if isinstance(expression, exp.Select): - yield from _traverse_select(scope) - elif isinstance(expression, exp.SetOperation): - yield from _traverse_ctes(scope) - yield from _traverse_union(scope) - return - elif isinstance(expression, exp.Subquery): - if scope.is_root: - yield from _traverse_select(scope) - else: - yield from _traverse_subqueries(scope) - elif isinstance(expression, exp.Table): - yield from _traverse_tables(scope) - elif isinstance(expression, exp.UDTF): - yield from _traverse_udtfs(scope) - elif isinstance(expression, exp.DDL): - if isinstance(expression.expression, exp.Query): - yield from _traverse_ctes(scope) - yield from _traverse_scope(Scope(expression.expression, cte_sources=scope.cte_sources)) - return - elif isinstance(expression, exp.DML): - yield from _traverse_ctes(scope) - for query in find_all_in_scope(expression, exp.Query): - # This check ensures we don't yield the CTE/nested queries twice - if not isinstance(query.parent, (exp.CTE, exp.Subquery)): - yield from _traverse_scope(Scope(query, cte_sources=scope.cte_sources)) - return - else: - logger.warning("Cannot traverse scope %s with type '%s'", expression, type(expression)) - return - - yield scope - - -def _traverse_select(scope): - yield from _traverse_ctes(scope) - yield from _traverse_tables(scope) - yield from _traverse_subqueries(scope) - - -def _traverse_union(scope): - prev_scope = None - union_scope_stack = [scope] - expression_stack = [scope.expression.right, scope.expression.left] - - while expression_stack: - expression = expression_stack.pop() - union_scope = union_scope_stack[-1] - - new_scope = union_scope.branch( - expression, - outer_columns=union_scope.outer_columns, - scope_type=ScopeType.UNION, - ) - - if isinstance(expression, exp.SetOperation): - yield from _traverse_ctes(new_scope) - - union_scope_stack.append(new_scope) - expression_stack.extend([expression.right, expression.left]) - continue - - for scope in _traverse_scope(new_scope): - yield scope - - if prev_scope: - union_scope_stack.pop() - union_scope.union_scopes = [prev_scope, scope] - prev_scope = union_scope - - yield union_scope - else: - prev_scope = scope - - -def _traverse_ctes(scope): - sources = {} - - for cte in scope.ctes: - cte_name = cte.alias - - # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. - # thus the recursive scope is the first section of the union. 
- with_ = scope.expression.args.get("with") - if with_ and with_.recursive: - union = cte.this - - if isinstance(union, exp.SetOperation): - sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) - - child_scope = None - - for child_scope in _traverse_scope( - scope.branch( - cte.this, - cte_sources=sources, - outer_columns=cte.alias_column_names, - scope_type=ScopeType.CTE, - ) - ): - yield child_scope - - # append the final child_scope yielded - if child_scope: - sources[cte_name] = child_scope - scope.cte_scopes.append(child_scope) - - scope.sources.update(sources) - scope.cte_sources.update(sources) - - -def _is_derived_table(expression: exp.Subquery) -> bool: - """ - We represent (tbl1 JOIN tbl2) as a Subquery, but it's not really a "derived table", - as it doesn't introduce a new scope. If an alias is present, it shadows all names - under the Subquery, so that's one exception to this rule. - """ - return isinstance(expression, exp.Subquery) and bool( - expression.alias or isinstance(expression.this, exp.UNWRAPPED_QUERIES) - ) - - -def _is_from_or_join(expression: exp.Expression) -> bool: - """ - Determine if `expression` is the FROM or JOIN clause of a SELECT statement. - """ - parent = expression.parent - - # Subqueries can be arbitrarily nested - while isinstance(parent, exp.Subquery): - parent = parent.parent - - return isinstance(parent, (exp.From, exp.Join)) - - -def _traverse_tables(scope): - sources = {} - - # Traverse FROMs, JOINs, and LATERALs in the order they are defined - expressions = [] - from_ = scope.expression.args.get("from") - if from_: - expressions.append(from_.this) - - for join in scope.expression.args.get("joins") or []: - expressions.append(join.this) - - if isinstance(scope.expression, exp.Table): - expressions.append(scope.expression) - - expressions.extend(scope.expression.args.get("laterals") or []) - - for expression in expressions: - if isinstance(expression, exp.Final): - expression = expression.this - if isinstance(expression, exp.Table): - table_name = expression.name - source_name = expression.alias_or_name - - if table_name in scope.sources and not expression.db: - # This is a reference to a parent source (e.g. a CTE), not an actual table, unless - # it is pivoted, because then we get back a new table and hence a new source. 
- pivots = expression.args.get("pivots") - if pivots: - sources[pivots[0].alias] = expression - else: - sources[source_name] = scope.sources[table_name] - elif source_name in sources: - sources[find_new_name(sources, table_name)] = expression - else: - sources[source_name] = expression - - # Make sure to not include the joins twice - if expression is not scope.expression: - expressions.extend(join.this for join in expression.args.get("joins") or []) - - continue - - if not isinstance(expression, exp.DerivedTable): - continue - - if isinstance(expression, exp.UDTF): - lateral_sources = sources - scope_type = ScopeType.UDTF - scopes = scope.udtf_scopes - elif _is_derived_table(expression): - lateral_sources = None - scope_type = ScopeType.DERIVED_TABLE - scopes = scope.derived_table_scopes - expressions.extend(join.this for join in expression.args.get("joins") or []) - else: - # Makes sure we check for possible sources in nested table constructs - expressions.append(expression.this) - expressions.extend(join.this for join in expression.args.get("joins") or []) - continue - - child_scope = None - - for child_scope in _traverse_scope( - scope.branch( - expression, - lateral_sources=lateral_sources, - outer_columns=expression.alias_column_names, - scope_type=scope_type, - ) - ): - yield child_scope - - # Tables without aliases will be set as "" - # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. - # Until then, this means that only a single, unaliased derived table is allowed (rather, - # the latest one wins. - sources[expression.alias] = child_scope - - # append the final child_scope yielded - if child_scope: - scopes.append(child_scope) - scope.table_scopes.append(child_scope) - - scope.sources.update(sources) - - -def _traverse_subqueries(scope): - for subquery in scope.subqueries: - top = None - for child_scope in _traverse_scope(scope.branch(subquery, scope_type=ScopeType.SUBQUERY)): - yield child_scope - top = child_scope - scope.subquery_scopes.append(top) - - -def _traverse_udtfs(scope): - if isinstance(scope.expression, exp.Unnest): - expressions = scope.expression.expressions - elif isinstance(scope.expression, exp.Lateral): - expressions = [scope.expression.this] - else: - expressions = [] - - sources = {} - for expression in expressions: - if _is_derived_table(expression): - top = None - for child_scope in _traverse_scope( - scope.branch( - expression, - scope_type=ScopeType.SUBQUERY, - outer_columns=expression.alias_column_names, - ) - ): - yield child_scope - top = child_scope - sources[expression.alias] = child_scope - - scope.subquery_scopes.append(top) - - scope.sources.update(sources) - - -def walk_in_scope(expression, bfs=True, prune=None): - """ - Returns a generator object which visits all nodes in the syntrax tree, stopping at - nodes that start child scopes. - - Args: - expression (exp.Expression): - bfs (bool): if set to True the BFS traversal order will be applied, - otherwise the DFS traversal will be used instead. - prune ((node, parent, arg_key) -> bool): callable that returns True if - the generator should stop traversing this branch of the tree. - - Yields: - tuple[exp.Expression, Optional[exp.Expression], str]: node, parent, arg key - """ - # We'll use this variable to pass state into the dfs generator. - # Whenever we set it to True, we exclude a subtree from traversal. 
- crossed_scope_boundary = False - - for node in expression.walk( - bfs=bfs, prune=lambda n: crossed_scope_boundary or (prune and prune(n)) - ): - crossed_scope_boundary = False - - yield node - - if node is expression: - continue - if ( - isinstance(node, exp.CTE) - or ( - isinstance(node.parent, (exp.From, exp.Join, exp.Subquery)) - and (_is_derived_table(node) or isinstance(node, exp.UDTF)) - ) - or isinstance(node, exp.UNWRAPPED_QUERIES) - ): - crossed_scope_boundary = True - - if isinstance(node, (exp.Subquery, exp.UDTF)): - # The following args are not actually in the inner scope, so we should visit them - for key in ("joins", "laterals", "pivots"): - for arg in node.args.get(key) or []: - yield from walk_in_scope(arg, bfs=bfs) - - -def find_all_in_scope(expression, expression_types, bfs=True): - """ - Returns a generator object which visits all nodes in this scope and only yields those that - match at least one of the specified expression types. - - This does NOT traverse into subscopes. - - Args: - expression (exp.Expression): - expression_types (tuple[type]|type): the expression type(s) to match. - bfs (bool): True to use breadth-first search, False to use depth-first. - - Yields: - exp.Expression: nodes - """ - for expression in walk_in_scope(expression, bfs=bfs): - if isinstance(expression, tuple(ensure_collection(expression_types))): - yield expression - - -def find_in_scope(expression, expression_types, bfs=True): - """ - Returns the first node in this scope which matches at least one of the specified types. - - This does NOT traverse into subscopes. - - Args: - expression (exp.Expression): - expression_types (tuple[type]|type): the expression type(s) to match. - bfs (bool): True to use breadth-first search, False to use depth-first. - - Returns: - exp.Expression: the node which matches the criteria or None if no node matching - the criteria was found. - """ - return next(find_all_in_scope(expression, expression_types, bfs=bfs), None) diff --git a/altimate_packages/sqlglot/optimizer/simplify.py b/altimate_packages/sqlglot/optimizer/simplify.py deleted file mode 100644 index 90bbe2bb6..000000000 --- a/altimate_packages/sqlglot/optimizer/simplify.py +++ /dev/null @@ -1,1587 +0,0 @@ -from __future__ import annotations - -import datetime -import logging -import functools -import itertools -import typing as t -from collections import deque, defaultdict -from functools import reduce - -import sqlglot -from sqlglot import Dialect, exp -from sqlglot.helper import first, merge_ranges, while_changing -from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope - -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - - DateTruncBinaryTransform = t.Callable[ - [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] - ] - -logger = logging.getLogger("sqlglot") - -# Final means that an expression should not be simplified -FINAL = "final" - -# Value ranges for byte-sized signed/unsigned integers -TINYINT_MIN = -128 -TINYINT_MAX = 127 -UTINYINT_MIN = 0 -UTINYINT_MAX = 255 - - -class UnsupportedUnit(Exception): - pass - - -def simplify( - expression: exp.Expression, - constant_propagation: bool = False, - coalesce_simplification: bool = False, - dialect: DialectType = None, -): - """ - Rewrite sqlglot AST to simplify expressions. 
- - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("TRUE AND TRUE") - >>> simplify(expression).sql() - 'TRUE' - - Args: - expression: expression to simplify - constant_propagation: whether the constant propagation rule should be used - coalesce_simplification: whether the simplify coalesce rule should be used. - This rule tries to remove coalesce functions, which can be useful in certain analyses but - can leave the query more verbose. - Returns: - sqlglot.Expression: simplified expression - """ - - dialect = Dialect.get_or_raise(dialect) - - def _simplify(expression): - pre_transformation_stack = [expression] - post_transformation_stack = [] - - while pre_transformation_stack: - node = pre_transformation_stack.pop() - - if node.meta.get(FINAL): - continue - - # group by expressions cannot be simplified, for example - # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 - # the projection must exactly match the group by key - group = node.args.get("group") - - if group and hasattr(node, "selects"): - groups = set(group.expressions) - group.meta[FINAL] = True - - for s in node.selects: - for n in s.walk(): - if n in groups: - s.meta[FINAL] = True - break - - having = node.args.get("having") - if having: - for n in having.walk(): - if n in groups: - having.meta[FINAL] = True - break - - parent = node.parent - root = node is expression - - new_node = rewrite_between(node) - new_node = uniq_sort(new_node, root) - new_node = absorb_and_eliminate(new_node, root) - new_node = simplify_concat(new_node) - new_node = simplify_conditionals(new_node) - - if constant_propagation: - new_node = propagate_constants(new_node, root) - - if new_node is not node: - node.replace(new_node) - - pre_transformation_stack.extend( - n for n in new_node.iter_expressions(reverse=True) if not n.meta.get(FINAL) - ) - post_transformation_stack.append((new_node, parent)) - - while post_transformation_stack: - node, parent = post_transformation_stack.pop() - root = node is expression - - # Resets parent, arg_key, index pointersโ€“ this is needed because some of the - # previous transformations mutate the AST, leading to an inconsistent state - for k, v in tuple(node.args.items()): - node.set(k, v) - - # Post-order transformations - new_node = simplify_not(node) - new_node = flatten(new_node) - new_node = simplify_connectors(new_node, root) - new_node = remove_complements(new_node, root) - - if coalesce_simplification: - new_node = simplify_coalesce(new_node, dialect) - - new_node.parent = parent - - new_node = simplify_literals(new_node, root) - new_node = simplify_equality(new_node) - new_node = simplify_parens(new_node) - new_node = simplify_datetrunc(new_node, dialect) - new_node = sort_comparison(new_node) - new_node = simplify_startswith(new_node) - - if new_node is not node: - node.replace(new_node) - - return new_node - - expression = while_changing(expression, _simplify) - remove_where_true(expression) - return expression - - -def catch(*exceptions): - """Decorator that ignores a simplification function if any of `exceptions` are raised""" - - def decorator(func): - def wrapped(expression, *args, **kwargs): - try: - return func(expression, *args, **kwargs) - except exceptions: - return expression - - return wrapped - - return decorator - - -def rewrite_between(expression: exp.Expression) -> exp.Expression: - """Rewrite x between y and z to x >= y AND x <= z. - - This is done because comparison simplification is only done on lt/lte/gt/gte. 
- """ - if isinstance(expression, exp.Between): - negate = isinstance(expression.parent, exp.Not) - - expression = exp.and_( - exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), - exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), - copy=False, - ) - - if negate: - expression = exp.paren(expression, copy=False) - - return expression - - -COMPLEMENT_COMPARISONS = { - exp.LT: exp.GTE, - exp.GT: exp.LTE, - exp.LTE: exp.GT, - exp.GTE: exp.LT, - exp.EQ: exp.NEQ, - exp.NEQ: exp.EQ, -} - -COMPLEMENT_SUBQUERY_PREDICATES = { - exp.All: exp.Any, - exp.Any: exp.All, -} - - -def simplify_not(expression): - """ - Demorgan's Law - NOT (x OR y) -> NOT x AND NOT y - NOT (x AND y) -> NOT x OR NOT y - """ - if isinstance(expression, exp.Not): - this = expression.this - if is_null(this): - return exp.null() - if this.__class__ in COMPLEMENT_COMPARISONS: - right = this.expression - complement_subquery_predicate = COMPLEMENT_SUBQUERY_PREDICATES.get(right.__class__) - if complement_subquery_predicate: - right = complement_subquery_predicate(this=right.this) - - return COMPLEMENT_COMPARISONS[this.__class__](this=this.this, expression=right) - if isinstance(this, exp.Paren): - condition = this.unnest() - if isinstance(condition, exp.And): - return exp.paren( - exp.or_( - exp.not_(condition.left, copy=False), - exp.not_(condition.right, copy=False), - copy=False, - ) - ) - if isinstance(condition, exp.Or): - return exp.paren( - exp.and_( - exp.not_(condition.left, copy=False), - exp.not_(condition.right, copy=False), - copy=False, - ) - ) - if is_null(condition): - return exp.null() - if always_true(this): - return exp.false() - if is_false(this): - return exp.true() - if isinstance(this, exp.Not): - # double negation - # NOT NOT x -> x - return this.this - return expression - - -def flatten(expression): - """ - A AND (B AND C) -> A AND B AND C - A OR (B OR C) -> A OR B OR C - """ - if isinstance(expression, exp.Connector): - for node in expression.args.values(): - child = node.unnest() - if isinstance(child, expression.__class__): - node.replace(child) - return expression - - -def simplify_connectors(expression, root=True): - def _simplify_connectors(expression, left, right): - if isinstance(expression, exp.And): - if is_false(left) or is_false(right): - return exp.false() - if is_zero(left) or is_zero(right): - return exp.false() - if is_null(left) or is_null(right): - return exp.null() - if always_true(left) and always_true(right): - return exp.true() - if always_true(left): - return right - if always_true(right): - return left - return _simplify_comparison(expression, left, right) - elif isinstance(expression, exp.Or): - if always_true(left) or always_true(right): - return exp.true() - if ( - (is_null(left) and is_null(right)) - or (is_null(left) and always_false(right)) - or (always_false(left) and is_null(right)) - ): - return exp.null() - if is_false(left): - return right - if is_false(right): - return left - return _simplify_comparison(expression, left, right, or_=True) - elif isinstance(expression, exp.Xor): - if left == right: - return exp.false() - - if isinstance(expression, exp.Connector): - return _flat_simplify(expression, _simplify_connectors, root) - return expression - - -LT_LTE = (exp.LT, exp.LTE) -GT_GTE = (exp.GT, exp.GTE) - -COMPARISONS = ( - *LT_LTE, - *GT_GTE, - exp.EQ, - exp.NEQ, - exp.Is, -) - -INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { - exp.LT: exp.GT, - exp.GT: exp.LT, - exp.LTE: exp.GTE, - exp.GTE: 
exp.LTE, -} - -NONDETERMINISTIC = (exp.Rand, exp.Randn) -AND_OR = (exp.And, exp.Or) - - -def _simplify_comparison(expression, left, right, or_=False): - if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): - ll, lr = left.args.values() - rl, rr = right.args.values() - - largs = {ll, lr} - rargs = {rl, rr} - - matching = largs & rargs - columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} - - if matching and columns: - try: - l = first(largs - columns) - r = first(rargs - columns) - except StopIteration: - return expression - - if l.is_number and r.is_number: - l = l.to_py() - r = r.to_py() - elif l.is_string and r.is_string: - l = l.name - r = r.name - else: - l = extract_date(l) - if not l: - return None - r = extract_date(r) - if not r: - return None - # python won't compare date and datetime, but many engines will upcast - l, r = cast_as_datetime(l), cast_as_datetime(r) - - for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): - if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): - return left if (av > bv if or_ else av <= bv) else right - if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): - return left if (av < bv if or_ else av >= bv) else right - - # we can't ever shortcut to true because the column could be null - if not or_: - if isinstance(a, exp.LT) and isinstance(b, GT_GTE): - if av <= bv: - return exp.false() - elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): - if av >= bv: - return exp.false() - elif isinstance(a, exp.EQ): - if isinstance(b, exp.LT): - return exp.false() if av >= bv else a - if isinstance(b, exp.LTE): - return exp.false() if av > bv else a - if isinstance(b, exp.GT): - return exp.false() if av <= bv else a - if isinstance(b, exp.GTE): - return exp.false() if av < bv else a - if isinstance(b, exp.NEQ): - return exp.false() if av == bv else a - return None - - -def remove_complements(expression, root=True): - """ - Removing complements. - - A AND NOT A -> FALSE - A OR NOT A -> TRUE - """ - if isinstance(expression, AND_OR) and (root or not expression.same_parent): - ops = set(expression.flatten()) - for op in ops: - if isinstance(op, exp.Not) and op.this in ops: - return exp.false() if isinstance(expression, exp.And) else exp.true() - - return expression - - -def uniq_sort(expression, root=True): - """ - Uniq and sort a connector. 
- - C AND A AND B AND B -> A AND B AND C - """ - if isinstance(expression, exp.Connector) and (root or not expression.same_parent): - flattened = tuple(expression.flatten()) - - if isinstance(expression, exp.Xor): - result_func = exp.xor - # Do not deduplicate XOR as A XOR A != A if A == True - deduped = None - arr = tuple((gen(e), e) for e in flattened) - else: - result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ - deduped = {gen(e): e for e in flattened} - arr = tuple(deduped.items()) - - # check if the operands are already sorted, if not sort them - # A AND C AND B -> A AND B AND C - for i, (sql, e) in enumerate(arr[1:]): - if sql < arr[i][0]: - expression = result_func(*(e for _, e in sorted(arr)), copy=False) - break - else: - # we didn't have to sort but maybe we need to dedup - if deduped and len(deduped) < len(flattened): - expression = result_func(*deduped.values(), copy=False) - - return expression - - -def absorb_and_eliminate(expression, root=True): - """ - absorption: - A AND (A OR B) -> A - A OR (A AND B) -> A - A AND (NOT A OR B) -> A AND B - A OR (NOT A AND B) -> A OR B - elimination: - (A AND B) OR (A AND NOT B) -> A - (A OR B) AND (A OR NOT B) -> A - """ - if isinstance(expression, AND_OR) and (root or not expression.same_parent): - kind = exp.Or if isinstance(expression, exp.And) else exp.And - - ops = tuple(expression.flatten()) - - # Initialize lookup tables: - # Set of all operands, used to find complements for absorption. - op_set = set() - # Sub-operands, used to find subsets for absorption. - subops = defaultdict(list) - # Pairs of complements, used for elimination. - pairs = defaultdict(list) - - # Populate the lookup tables - for op in ops: - op_set.add(op) - - if not isinstance(op, kind): - # In cases like: A OR (A AND B) - # Subop will be: ^ - subops[op].append({op}) - continue - - # In cases like: (A AND B) OR (A AND B AND C) - # Subops will be: ^ ^ - subset = set(op.flatten()) - for i in subset: - subops[i].append(subset) - - a, b = op.unnest_operands() - if isinstance(a, exp.Not): - pairs[frozenset((a.this, b))].append((op, b)) - if isinstance(b, exp.Not): - pairs[frozenset((a, b.this))].append((op, a)) - - for op in ops: - if not isinstance(op, kind): - continue - - a, b = op.unnest_operands() - - # Absorb - if isinstance(a, exp.Not) and a.this in op_set: - a.replace(exp.true() if kind == exp.And else exp.false()) - continue - if isinstance(b, exp.Not) and b.this in op_set: - b.replace(exp.true() if kind == exp.And else exp.false()) - continue - superset = set(op.flatten()) - if any(any(subset < superset for subset in subops[i]) for i in superset): - op.replace(exp.false() if kind == exp.And else exp.true()) - continue - - # Eliminate - for other, complement in pairs[frozenset((a, b))]: - op.replace(complement) - other.replace(complement) - - return expression - - -def propagate_constants(expression, root=True): - """ - Propagate constants for conjunctions in DNF: - - SELECT * FROM t WHERE a = b AND b = 5 becomes - SELECT * FROM t WHERE a = 5 AND b = 5 - - Reference: https://www.sqlite.org/optoverview.html - """ - - if ( - isinstance(expression, exp.And) - and (root or not expression.same_parent) - and sqlglot.optimizer.normalize.normalized(expression, dnf=True) - ): - constant_mapping = {} - for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): - if isinstance(expr, exp.EQ): - l, r = expr.left, expr.right - - # TODO: create a helper that can be used to detect nested literal expressions such - # as 
CAST(123456 AS BIGINT), since we usually want to treat those as literals too - if isinstance(l, exp.Column) and isinstance(r, exp.Literal): - constant_mapping[l] = (id(l), r) - - if constant_mapping: - for column in find_all_in_scope(expression, exp.Column): - parent = column.parent - column_id, constant = constant_mapping.get(column) or (None, None) - if ( - column_id is not None - and id(column) != column_id - and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) - ): - column.replace(constant.copy()) - - return expression - - -INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { - exp.DateAdd: exp.Sub, - exp.DateSub: exp.Add, - exp.DatetimeAdd: exp.Sub, - exp.DatetimeSub: exp.Add, -} - -INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { - **INVERSE_DATE_OPS, - exp.Add: exp.Sub, - exp.Sub: exp.Add, -} - - -def _is_number(expression: exp.Expression) -> bool: - return expression.is_number - - -def _is_interval(expression: exp.Expression) -> bool: - return isinstance(expression, exp.Interval) and extract_interval(expression) is not None - - -@catch(ModuleNotFoundError, UnsupportedUnit) -def simplify_equality(expression: exp.Expression) -> exp.Expression: - """ - Use the subtraction and addition properties of equality to simplify expressions: - - x + 1 = 3 becomes x = 2 - - There are two binary operations in the above expression: + and = - Here's how we reference all the operands in the code below: - - l r - x + 1 = 3 - a b - """ - if isinstance(expression, COMPARISONS): - l, r = expression.left, expression.right - - if l.__class__ not in INVERSE_OPS: - return expression - - if r.is_number: - a_predicate = _is_number - b_predicate = _is_number - elif _is_date_literal(r): - a_predicate = _is_date_literal - b_predicate = _is_interval - else: - return expression - - if l.__class__ in INVERSE_DATE_OPS: - l = t.cast(exp.IntervalOp, l) - a = l.this - b = l.interval() - else: - l = t.cast(exp.Binary, l) - a, b = l.left, l.right - - if not a_predicate(a) and b_predicate(b): - pass - elif not a_predicate(b) and b_predicate(a): - a, b = b, a - else: - return expression - - return expression.__class__( - this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) - ) - return expression - - -def simplify_literals(expression, root=True): - if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): - return _flat_simplify(expression, _simplify_binary, root) - - if isinstance(expression, exp.Neg) and isinstance(expression.this, exp.Neg): - return expression.this.this - - if type(expression) in INVERSE_DATE_OPS: - return _simplify_binary(expression, expression.this, expression.interval()) or expression - - return expression - - -NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) - - -def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: - if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): - this = _simplify_integer_cast(expr.this) - else: - this = expr.this - - if isinstance(expr, exp.Cast) and this.is_int: - num = this.to_py() - - # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any - # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is - # engine-dependent - if ( - TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES - ) or ( - UTINYINT_MIN <= num <= UTINYINT_MAX - and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES - ): - return this - - return expr - - -def _simplify_binary(expression, a, b): - if isinstance(expression, COMPARISONS): - a = _simplify_integer_cast(a) - b = _simplify_integer_cast(b) - - if isinstance(expression, exp.Is): - if isinstance(b, exp.Not): - c = b.this - not_ = True - else: - c = b - not_ = False - - if is_null(c): - if isinstance(a, exp.Literal): - return exp.true() if not_ else exp.false() - if is_null(a): - return exp.false() if not_ else exp.true() - elif isinstance(expression, NULL_OK): - return None - elif is_null(a) or is_null(b): - return exp.null() - - if a.is_number and b.is_number: - num_a = a.to_py() - num_b = b.to_py() - - if isinstance(expression, exp.Add): - return exp.Literal.number(num_a + num_b) - if isinstance(expression, exp.Mul): - return exp.Literal.number(num_a * num_b) - - # We only simplify Sub, Div if a and b have the same parent because they're not associative - if isinstance(expression, exp.Sub): - return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None - if isinstance(expression, exp.Div): - # engines have differing int div behavior so intdiv is not safe - if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: - return None - return exp.Literal.number(num_a / num_b) - - boolean = eval_boolean(expression, num_a, num_b) - - if boolean: - return boolean - elif a.is_string and b.is_string: - boolean = eval_boolean(expression, a.this, b.this) - - if boolean: - return boolean - elif _is_date_literal(a) and isinstance(b, exp.Interval): - date, b = extract_date(a), extract_interval(b) - if date and b: - if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): - return date_literal(date + b, extract_type(a)) - if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): - return date_literal(date - b, extract_type(a)) - elif isinstance(a, exp.Interval) and _is_date_literal(b): - a, date = extract_interval(a), extract_date(b) - # you cannot subtract a date from an interval - if a and b and isinstance(expression, exp.Add): - return date_literal(a + date, extract_type(b)) - elif _is_date_literal(a) and _is_date_literal(b): - if isinstance(expression, exp.Predicate): - a, b = extract_date(a), extract_date(b) - boolean = eval_boolean(expression, a, b) - if boolean: - return boolean - - return None - - -def simplify_parens(expression): - if not isinstance(expression, exp.Paren): - return expression - - this = expression.this - parent = expression.parent - parent_is_predicate = isinstance(parent, exp.Predicate) - - if ( - not isinstance(this, exp.Select) - and not isinstance(parent, (exp.SubqueryPredicate, exp.Bracket)) - and ( - not isinstance(parent, (exp.Condition, exp.Binary)) - or isinstance(parent, exp.Paren) - or ( - not isinstance(this, exp.Binary) - and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) - ) - or (isinstance(this, exp.Predicate) and not parent_is_predicate) - or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) - or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) - or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) - ) - ): - return this - return expression - - -def 
_is_nonnull_constant(expression: exp.Expression) -> bool: - return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) - - -def _is_constant(expression: exp.Expression) -> bool: - return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) - - -def simplify_coalesce(expression: exp.Expression, dialect: DialectType) -> exp.Expression: - # COALESCE(x) -> x - if ( - isinstance(expression, exp.Coalesce) - and (not expression.expressions or _is_nonnull_constant(expression.this)) - # COALESCE is also used as a Spark partitioning hint - and not isinstance(expression.parent, exp.Hint) - ): - return expression.this - - # We can't convert `COALESCE(x, 1) = 2` into `NOT x IS NULL AND x = 2` for redshift, - # because they are not always equivalent. For example, if `x` is `NULL` and it comes - # from a table, then the result is `NULL`, despite `FALSE AND NULL` evaluating to `FALSE` - if dialect == "redshift": - return expression - - if not isinstance(expression, COMPARISONS): - return expression - - if isinstance(expression.left, exp.Coalesce): - coalesce = expression.left - other = expression.right - elif isinstance(expression.right, exp.Coalesce): - coalesce = expression.right - other = expression.left - else: - return expression - - # This transformation is valid for non-constants, - # but it really only does anything if they are both constants. - if not _is_constant(other): - return expression - - # Find the first constant arg - for arg_index, arg in enumerate(coalesce.expressions): - if _is_constant(arg): - break - else: - return expression - - coalesce.set("expressions", coalesce.expressions[:arg_index]) - - # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, - # since we already remove COALESCE at the top of this function. 
- coalesce = coalesce if coalesce.expressions else coalesce.this - - # This expression is more complex than when we started, but it will get simplified further - return exp.paren( - exp.or_( - exp.and_( - coalesce.is_(exp.null()).not_(copy=False), - expression.copy(), - copy=False, - ), - exp.and_( - coalesce.is_(exp.null()), - type(expression)(this=arg.copy(), expression=other.copy()), - copy=False, - ), - copy=False, - ) - ) - - -CONCATS = (exp.Concat, exp.DPipe) - - -def simplify_concat(expression): - """Reduces all groups that contain string literals by concatenating them.""" - if not isinstance(expression, CONCATS) or ( - # We can't reduce a CONCAT_WS call if we don't statically know the separator - isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string - ): - return expression - - if isinstance(expression, exp.ConcatWs): - sep_expr, *expressions = expression.expressions - sep = sep_expr.name - concat_type = exp.ConcatWs - args = {} - else: - expressions = expression.expressions - sep = "" - concat_type = exp.Concat - args = { - "safe": expression.args.get("safe"), - "coalesce": expression.args.get("coalesce"), - } - - new_args = [] - for is_string_group, group in itertools.groupby( - expressions or expression.flatten(), lambda e: e.is_string - ): - if is_string_group: - new_args.append(exp.Literal.string(sep.join(string.name for string in group))) - else: - new_args.extend(group) - - if len(new_args) == 1 and new_args[0].is_string: - return new_args[0] - - if concat_type is exp.ConcatWs: - new_args = [sep_expr] + new_args - elif isinstance(expression, exp.DPipe): - return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) - - return concat_type(expressions=new_args, **args) - - -def simplify_conditionals(expression): - """Simplifies expressions like IF, CASE if their condition is statically known.""" - if isinstance(expression, exp.Case): - this = expression.this - for case in expression.args["ifs"]: - cond = case.this - if this: - # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... - cond = cond.replace(this.pop().eq(cond)) - - if always_true(cond): - return case.args["true"] - - if always_false(cond): - case.pop() - if not expression.args["ifs"]: - return expression.args.get("default") or exp.null() - elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): - if always_true(expression.this): - return expression.args["true"] - if always_false(expression.this): - return expression.args.get("false") or exp.null() - - return expression - - -def simplify_startswith(expression: exp.Expression) -> exp.Expression: - """ - Reduces a prefix check to either TRUE or FALSE if both the string and the - prefix are statically known. 
- - Example: - >>> from sqlglot import parse_one - >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() - 'TRUE' - """ - if ( - isinstance(expression, exp.StartsWith) - and expression.this.is_string - and expression.expression.is_string - ): - return exp.convert(expression.name.startswith(expression.expression.name)) - - return expression - - -DateRange = t.Tuple[datetime.date, datetime.date] - - -def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: - """ - Get the date range for a DATE_TRUNC equality comparison: - - Example: - _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) - Returns: - tuple of [min, max) or None if a value can never be equal to `date` for `unit` - """ - floor = date_floor(date, unit, dialect) - - if date != floor: - # This will always be False, except for NULL values. - return None - - return floor, floor + interval(unit) - - -def _datetrunc_eq_expression( - left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] -) -> exp.Expression: - """Get the logical expression for a date range""" - return exp.and_( - left >= date_literal(drange[0], target_type), - left < date_literal(drange[1], target_type), - copy=False, - ) - - -def _datetrunc_eq( - left: exp.Expression, - date: datetime.date, - unit: str, - dialect: Dialect, - target_type: t.Optional[exp.DataType], -) -> t.Optional[exp.Expression]: - drange = _datetrunc_range(date, unit, dialect) - if not drange: - return None - - return _datetrunc_eq_expression(left, drange, target_type) - - -def _datetrunc_neq( - left: exp.Expression, - date: datetime.date, - unit: str, - dialect: Dialect, - target_type: t.Optional[exp.DataType], -) -> t.Optional[exp.Expression]: - drange = _datetrunc_range(date, unit, dialect) - if not drange: - return None - - return exp.and_( - left < date_literal(drange[0], target_type), - left >= date_literal(drange[1], target_type), - copy=False, - ) - - -DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { - exp.LT: lambda l, dt, u, d, t: l - < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), - exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), - exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), - exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), - exp.EQ: _datetrunc_eq, - exp.NEQ: _datetrunc_neq, -} -DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} -DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) - - -def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: - return isinstance(left, DATETRUNCS) and _is_date_literal(right) - - -@catch(ModuleNotFoundError, UnsupportedUnit) -def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: - """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" - comparison = expression.__class__ - - if isinstance(expression, DATETRUNCS): - this = expression.this - trunc_type = extract_type(this) - date = extract_date(this) - if date and expression.unit: - return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) - elif comparison not in DATETRUNC_COMPARISONS: - return expression - - if isinstance(expression, exp.Binary): - l, r = expression.left, expression.right - - if not _is_datetrunc_predicate(l, r): - return expression - - l = t.cast(exp.DateTrunc, 
l) - trunc_arg = l.this - unit = l.unit.name.lower() - date = extract_date(r) - - if not date: - return expression - - return ( - DATETRUNC_BINARY_COMPARISONS[comparison]( - trunc_arg, date, unit, dialect, extract_type(r) - ) - or expression - ) - - if isinstance(expression, exp.In): - l = expression.this - rs = expression.expressions - - if rs and all(_is_datetrunc_predicate(l, r) for r in rs): - l = t.cast(exp.DateTrunc, l) - unit = l.unit.name.lower() - - ranges = [] - for r in rs: - date = extract_date(r) - if not date: - return expression - drange = _datetrunc_range(date, unit, dialect) - if drange: - ranges.append(drange) - - if not ranges: - return expression - - ranges = merge_ranges(ranges) - target_type = extract_type(*rs) - - return exp.or_( - *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False - ) - - return expression - - -def sort_comparison(expression: exp.Expression) -> exp.Expression: - if expression.__class__ in COMPLEMENT_COMPARISONS: - l, r = expression.this, expression.expression - l_column = isinstance(l, exp.Column) - r_column = isinstance(r, exp.Column) - l_const = _is_constant(l) - r_const = _is_constant(r) - - if ( - (l_column and not r_column) - or (r_const and not l_const) - or isinstance(r, exp.SubqueryPredicate) - ): - return expression - if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): - return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( - this=r, expression=l - ) - return expression - - -# CROSS joins result in an empty table if the right table is empty. -# So we can only simplify certain types of joins to CROSS. -# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x -JOINS = { - ("", ""), - ("", "INNER"), - ("RIGHT", ""), - ("RIGHT", "OUTER"), -} - - -def remove_where_true(expression): - for where in expression.find_all(exp.Where): - if always_true(where.this): - where.pop() - for join in expression.find_all(exp.Join): - if ( - always_true(join.args.get("on")) - and not join.args.get("using") - and not join.args.get("method") - and (join.side, join.kind) in JOINS - ): - join.args["on"].pop() - join.set("side", None) - join.set("kind", "CROSS") - - -def always_true(expression): - return (isinstance(expression, exp.Boolean) and expression.this) or ( - isinstance(expression, exp.Literal) and not is_zero(expression) - ) - - -def always_false(expression): - return is_false(expression) or is_null(expression) or is_zero(expression) - - -def is_zero(expression): - return isinstance(expression, exp.Literal) and expression.to_py() == 0 - - -def is_complement(a, b): - return isinstance(b, exp.Not) and b.this == a - - -def is_false(a: exp.Expression) -> bool: - return type(a) is exp.Boolean and not a.this - - -def is_null(a: exp.Expression) -> bool: - return type(a) is exp.Null - - -def eval_boolean(expression, a, b): - if isinstance(expression, (exp.EQ, exp.Is)): - return boolean_literal(a == b) - if isinstance(expression, exp.NEQ): - return boolean_literal(a != b) - if isinstance(expression, exp.GT): - return boolean_literal(a > b) - if isinstance(expression, exp.GTE): - return boolean_literal(a >= b) - if isinstance(expression, exp.LT): - return boolean_literal(a < b) - if isinstance(expression, exp.LTE): - return boolean_literal(a <= b) - return None - - -def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: - if isinstance(value, datetime.datetime): - return value.date() - if isinstance(value, datetime.date): - return value - try: - return 
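Taken together, `_datetrunc_range` and `DATETRUNC_BINARY_COMPARISONS` turn truncation predicates into sargable range checks. A minimal sketch through the public optimizer; it assumes `python-dateutil` is installed, since `interval()` above relies on `relativedelta` and `simplify_datetrunc` silently no-ops without it:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import simplify

# DATE_TRUNC('YEAR', x) = DATE '2021-01-01' becomes a half-open range on x,
# roughly x >= '2021-01-01' AND x < '2022-01-01' (see _datetrunc_range above)
sql = "SELECT * FROM t WHERE DATE_TRUNC('YEAR', x) = CAST('2021-01-01' AS DATE)"
print(simplify(parse_one(sql)).sql())
```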
datetime.datetime.fromisoformat(value).date() - except ValueError: - return None - - -def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: - if isinstance(value, datetime.datetime): - return value - if isinstance(value, datetime.date): - return datetime.datetime(year=value.year, month=value.month, day=value.day) - try: - return datetime.datetime.fromisoformat(value) - except ValueError: - return None - - -def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: - if not value: - return None - if to.is_type(exp.DataType.Type.DATE): - return cast_as_date(value) - if to.is_type(*exp.DataType.TEMPORAL_TYPES): - return cast_as_datetime(value) - return None - - -def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: - if isinstance(cast, exp.Cast): - to = cast.to - elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): - to = exp.DataType.build(exp.DataType.Type.DATE) - else: - return None - - if isinstance(cast.this, exp.Literal): - value: t.Any = cast.this.name - elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): - value = extract_date(cast.this) - else: - return None - return cast_value(value, to) - - -def _is_date_literal(expression: exp.Expression) -> bool: - return extract_date(expression) is not None - - -def extract_interval(expression): - try: - n = int(expression.this.to_py()) - unit = expression.text("unit").lower() - return interval(unit, n) - except (UnsupportedUnit, ModuleNotFoundError, ValueError): - return None - - -def extract_type(*expressions): - target_type = None - for expression in expressions: - target_type = expression.to if isinstance(expression, exp.Cast) else expression.type - if target_type: - break - - return target_type - - -def date_literal(date, target_type=None): - if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): - target_type = ( - exp.DataType.Type.DATETIME - if isinstance(date, datetime.datetime) - else exp.DataType.Type.DATE - ) - - return exp.cast(exp.Literal.string(date), target_type) - - -def interval(unit: str, n: int = 1): - from dateutil.relativedelta import relativedelta - - if unit == "year": - return relativedelta(years=1 * n) - if unit == "quarter": - return relativedelta(months=3 * n) - if unit == "month": - return relativedelta(months=1 * n) - if unit == "week": - return relativedelta(weeks=1 * n) - if unit == "day": - return relativedelta(days=1 * n) - if unit == "hour": - return relativedelta(hours=1 * n) - if unit == "minute": - return relativedelta(minutes=1 * n) - if unit == "second": - return relativedelta(seconds=1 * n) - - raise UnsupportedUnit(f"Unsupported unit: {unit}") - - -def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: - if unit == "year": - return d.replace(month=1, day=1) - if unit == "quarter": - if d.month <= 3: - return d.replace(month=1, day=1) - elif d.month <= 6: - return d.replace(month=4, day=1) - elif d.month <= 9: - return d.replace(month=7, day=1) - else: - return d.replace(month=10, day=1) - if unit == "month": - return d.replace(month=d.month, day=1) - if unit == "week": - # Assuming week starts on Monday (0) and ends on Sunday (6) - return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) - if unit == "day": - return d - - raise UnsupportedUnit(f"Unsupported unit: {unit}") - - -def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: - floor = date_floor(d, unit, dialect) - - if floor == d: - return d - 
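The floor/ceil helpers above snap a date to the start of its unit and use `relativedelta` for the exclusive upper bound. A small standalone sketch of the quarter case (assumes `python-dateutil` is installed; `quarter_floor` is an illustrative name, not part of the module):

```python
import datetime

from dateutil.relativedelta import relativedelta


def quarter_floor(d: datetime.date) -> datetime.date:
    # Mirrors date_floor(d, "quarter", ...): snap to the first day of the quarter
    first_month_of_quarter = 3 * ((d.month - 1) // 3) + 1
    return d.replace(month=first_month_of_quarter, day=1)


d = datetime.date(2021, 8, 17)
print(quarter_floor(d))                            # 2021-07-01
print(quarter_floor(d) + relativedelta(months=3))  # 2021-10-01, the exclusive upper bound
```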
- return floor + interval(unit) - - -def boolean_literal(condition): - return exp.true() if condition else exp.false() - - -def _flat_simplify(expression, simplifier, root=True): - if root or not expression.same_parent: - operands = [] - queue = deque(expression.flatten(unnest=False)) - size = len(queue) - - while queue: - a = queue.popleft() - - for b in queue: - result = simplifier(expression, a, b) - - if result and result is not expression: - queue.remove(b) - queue.appendleft(result) - break - else: - operands.append(a) - - if len(operands) < size: - return functools.reduce( - lambda a, b: expression.__class__(this=a, expression=b), operands - ) - return expression - - -def gen(expression: t.Any, comments: bool = False) -> str: - """Simple pseudo sql generator for quickly generating sortable and uniq strings. - - Sorting and deduping sql is a necessary step for optimization. Calling the actual - generator is expensive so we have a bare minimum sql generator here. - - Args: - expression: the expression to convert into a SQL string. - comments: whether to include the expression's comments. - """ - return Gen().gen(expression, comments=comments) - - -class Gen: - def __init__(self): - self.stack = [] - self.sqls = [] - - def gen(self, expression: exp.Expression, comments: bool = False) -> str: - self.stack = [expression] - self.sqls.clear() - - while self.stack: - node = self.stack.pop() - - if isinstance(node, exp.Expression): - if comments and node.comments: - self.stack.append(f" /*{','.join(node.comments)}*/") - - exp_handler_name = f"{node.key}_sql" - - if hasattr(self, exp_handler_name): - getattr(self, exp_handler_name)(node) - elif isinstance(node, exp.Func): - self._function(node) - else: - key = node.key.upper() - self.stack.append(f"{key} " if self._args(node) else key) - elif type(node) is list: - for n in reversed(node): - if n is not None: - self.stack.extend((n, ",")) - if node: - self.stack.pop() - else: - if node is not None: - self.sqls.append(str(node)) - - return "".join(self.sqls) - - def add_sql(self, e: exp.Add) -> None: - self._binary(e, " + ") - - def alias_sql(self, e: exp.Alias) -> None: - self.stack.extend( - ( - e.args.get("alias"), - " AS ", - e.args.get("this"), - ) - ) - - def and_sql(self, e: exp.And) -> None: - self._binary(e, " AND ") - - def anonymous_sql(self, e: exp.Anonymous) -> None: - this = e.this - if isinstance(this, str): - name = this.upper() - elif isinstance(this, exp.Identifier): - name = this.this - name = f'"{name}"' if this.quoted else name.upper() - else: - raise ValueError( - f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
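The `Gen` walker defined above exists purely to produce cheap, deterministic strings used to sort and dedupe operands during simplification, as its docstring notes. A minimal sketch, assuming a sqlglot build that still exposes module-level `gen` from `sqlglot.optimizer.simplify`:

```python
from sqlglot import parse_one
from sqlglot.optimizer.simplify import gen

# gen() trades fidelity for speed: the string only needs to be stable and unique
# enough to order and deduplicate expressions, not to be valid, pretty SQL.
expr = parse_one("COALESCE(a, 1) = b + 2")
print(gen(expr))
```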
- ) - - self.stack.extend( - ( - ")", - e.expressions, - "(", - name, - ) - ) - - def between_sql(self, e: exp.Between) -> None: - self.stack.extend( - ( - e.args.get("high"), - " AND ", - e.args.get("low"), - " BETWEEN ", - e.this, - ) - ) - - def boolean_sql(self, e: exp.Boolean) -> None: - self.stack.append("TRUE" if e.this else "FALSE") - - def bracket_sql(self, e: exp.Bracket) -> None: - self.stack.extend( - ( - "]", - e.expressions, - "[", - e.this, - ) - ) - - def column_sql(self, e: exp.Column) -> None: - for p in reversed(e.parts): - self.stack.extend((p, ".")) - self.stack.pop() - - def datatype_sql(self, e: exp.DataType) -> None: - self._args(e, 1) - self.stack.append(f"{e.this.name} ") - - def div_sql(self, e: exp.Div) -> None: - self._binary(e, " / ") - - def dot_sql(self, e: exp.Dot) -> None: - self._binary(e, ".") - - def eq_sql(self, e: exp.EQ) -> None: - self._binary(e, " = ") - - def from_sql(self, e: exp.From) -> None: - self.stack.extend((e.this, "FROM ")) - - def gt_sql(self, e: exp.GT) -> None: - self._binary(e, " > ") - - def gte_sql(self, e: exp.GTE) -> None: - self._binary(e, " >= ") - - def identifier_sql(self, e: exp.Identifier) -> None: - self.stack.append(f'"{e.this}"' if e.quoted else e.this) - - def ilike_sql(self, e: exp.ILike) -> None: - self._binary(e, " ILIKE ") - - def in_sql(self, e: exp.In) -> None: - self.stack.append(")") - self._args(e, 1) - self.stack.extend( - ( - "(", - " IN ", - e.this, - ) - ) - - def intdiv_sql(self, e: exp.IntDiv) -> None: - self._binary(e, " DIV ") - - def is_sql(self, e: exp.Is) -> None: - self._binary(e, " IS ") - - def like_sql(self, e: exp.Like) -> None: - self._binary(e, " Like ") - - def literal_sql(self, e: exp.Literal) -> None: - self.stack.append(f"'{e.this}'" if e.is_string else e.this) - - def lt_sql(self, e: exp.LT) -> None: - self._binary(e, " < ") - - def lte_sql(self, e: exp.LTE) -> None: - self._binary(e, " <= ") - - def mod_sql(self, e: exp.Mod) -> None: - self._binary(e, " % ") - - def mul_sql(self, e: exp.Mul) -> None: - self._binary(e, " * ") - - def neg_sql(self, e: exp.Neg) -> None: - self._unary(e, "-") - - def neq_sql(self, e: exp.NEQ) -> None: - self._binary(e, " <> ") - - def not_sql(self, e: exp.Not) -> None: - self._unary(e, "NOT ") - - def null_sql(self, e: exp.Null) -> None: - self.stack.append("NULL") - - def or_sql(self, e: exp.Or) -> None: - self._binary(e, " OR ") - - def paren_sql(self, e: exp.Paren) -> None: - self.stack.extend( - ( - ")", - e.this, - "(", - ) - ) - - def sub_sql(self, e: exp.Sub) -> None: - self._binary(e, " - ") - - def subquery_sql(self, e: exp.Subquery) -> None: - self._args(e, 2) - alias = e.args.get("alias") - if alias: - self.stack.append(alias) - self.stack.extend((")", e.this, "(")) - - def table_sql(self, e: exp.Table) -> None: - self._args(e, 4) - alias = e.args.get("alias") - if alias: - self.stack.append(alias) - for p in reversed(e.parts): - self.stack.extend((p, ".")) - self.stack.pop() - - def tablealias_sql(self, e: exp.TableAlias) -> None: - columns = e.columns - - if columns: - self.stack.extend((")", columns, "(")) - - self.stack.extend((e.this, " AS ")) - - def var_sql(self, e: exp.Var) -> None: - self.stack.append(e.this) - - def _binary(self, e: exp.Binary, op: str) -> None: - self.stack.extend((e.expression, op, e.this)) - - def _unary(self, e: exp.Unary, op: str) -> None: - self.stack.extend((e.this, op)) - - def _function(self, e: exp.Func) -> None: - self.stack.extend( - ( - ")", - list(e.args.values()), - "(", - e.sql_name(), - ) - ) - - def 
_args(self, node: exp.Expression, arg_index: int = 0) -> bool: - kvs = [] - arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types - - for k in arg_types or arg_types: - v = node.args.get(k) - - if v is not None: - kvs.append([f":{k}", v]) - if kvs: - self.stack.append(kvs) - return True - return False diff --git a/altimate_packages/sqlglot/optimizer/unnest_subqueries.py b/altimate_packages/sqlglot/optimizer/unnest_subqueries.py deleted file mode 100644 index eef97f1c0..000000000 --- a/altimate_packages/sqlglot/optimizer/unnest_subqueries.py +++ /dev/null @@ -1,302 +0,0 @@ -from sqlglot import exp -from sqlglot.helper import name_sequence -from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope - - -def unnest_subqueries(expression): - """ - Rewrite sqlglot AST to convert some predicates with subqueries into joins. - - Convert scalar subqueries into cross joins. - Convert correlated or vectorized subqueries into a group by so it is not a many to many left join. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") - >>> unnest_subqueries(expression).sql() - 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1' - - Args: - expression (sqlglot.Expression): expression to unnest - Returns: - sqlglot.Expression: unnested expression - """ - next_alias_name = name_sequence("_u_") - - for scope in traverse_scope(expression): - select = scope.expression - parent = select.parent_select - if not parent: - continue - if scope.external_columns: - decorrelate(select, parent, scope.external_columns, next_alias_name) - elif scope.scope_type == ScopeType.SUBQUERY: - unnest(select, parent, next_alias_name) - - return expression - - -def unnest(select, parent_select, next_alias_name): - if len(select.selects) > 1: - return - - predicate = select.find_ancestor(exp.Condition) - if ( - not predicate - or parent_select is not predicate.parent_select - or not parent_select.args.get("from") - ): - return - - if isinstance(select, exp.SetOperation): - select = exp.select(*select.selects).from_(select.subquery(next_alias_name())) - - alias = next_alias_name() - clause = predicate.find_ancestor(exp.Having, exp.Where, exp.Join) - - # This subquery returns a scalar and can just be converted to a cross join - if not isinstance(predicate, (exp.In, exp.Any)): - column = exp.column(select.selects[0].alias_or_name, alias) - - clause_parent_select = clause.parent_select if clause else None - - if (isinstance(clause, exp.Having) and clause_parent_select is parent_select) or ( - (not clause or clause_parent_select is not parent_select) - and ( - parent_select.args.get("group") - or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects) - ) - ): - column = exp.Max(this=column) - elif not isinstance(select.parent, exp.Subquery): - return - - _replace(select.parent, column) - parent_select.join(select, join_type="CROSS", join_alias=alias, copy=False) - return - - if select.find(exp.Limit, exp.Offset): - return - - if isinstance(predicate, exp.Any): - predicate = predicate.find_ancestor(exp.EQ) - - if not predicate or parent_select is not predicate.parent_select: - return - - column = _other_operand(predicate) - value = select.selects[0] - - join_key = exp.column(value.alias, alias) - join_key_not_null = join_key.is_(exp.null()).not_() - - if isinstance(clause, exp.Join): - _replace(predicate, 
exp.true()) - parent_select.where(join_key_not_null, copy=False) - else: - _replace(predicate, join_key_not_null) - - group = select.args.get("group") - - if group: - if {value.this} != set(group.expressions): - select = ( - exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias)) - .from_(select.subquery("_q", copy=False), copy=False) - .group_by(exp.column(value.alias, "_q"), copy=False) - ) - elif not find_in_scope(value.this, exp.AggFunc): - select = select.group_by(value.this, copy=False) - - parent_select.join( - select, - on=column.eq(join_key), - join_type="LEFT", - join_alias=alias, - copy=False, - ) - - -def decorrelate(select, parent_select, external_columns, next_alias_name): - where = select.args.get("where") - - if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset): - return - - table_alias = next_alias_name() - keys = [] - - # for all external columns in the where statement, find the relevant predicate - # keys to convert it into a join - for column in external_columns: - if column.find_ancestor(exp.Where) is not where: - return - - predicate = column.find_ancestor(exp.Predicate) - - if not predicate or predicate.find_ancestor(exp.Where) is not where: - return - - if isinstance(predicate, exp.Binary): - key = ( - predicate.right - if any(node is column for node in predicate.left.walk()) - else predicate.left - ) - else: - return - - keys.append((key, column, predicate)) - - if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): - return - - is_subquery_projection = any( - node is select.parent - for node in map(lambda s: s.unalias(), parent_select.selects) - if isinstance(node, exp.Subquery) - ) - - value = select.selects[0] - key_aliases = {} - group_by = [] - - for key, _, predicate in keys: - # if we filter on the value of the subquery, it needs to be unique - if key == value.this: - key_aliases[key] = value.alias - group_by.append(key) - else: - if key not in key_aliases: - key_aliases[key] = next_alias_name() - # all predicates that are equalities must also be in the unique - # so that we don't do a many to many join - if isinstance(predicate, exp.EQ) and key not in group_by: - group_by.append(key) - - parent_predicate = select.find_ancestor(exp.Predicate) - - # if the value of the subquery is not an agg or a key, we need to collect it into an array - # so that it can be grouped. For subquery projections, we use a MAX aggregation instead. 
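The module docstring above already shows the scalar case; the same entry point also drives the correlated rewrite implemented by `decorrelate`. A minimal sketch against the public module path (assuming the upstream `sqlglot` package is used instead of this vendored copy):

```python
from sqlglot import parse_one
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

# The correlated scalar subquery is pulled up into a grouped LEFT JOIN,
# as in the module docstring above.
sql = "SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1"
print(unnest_subqueries(parse_one(sql)).sql())
```

Grouping on the join keys is what keeps the rewritten LEFT JOIN from becoming a many-to-many join, which is why equality predicates are folded into the GROUP BY above.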
- agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg - if not value.find(exp.AggFunc) and value.this not in group_by: - select.select( - exp.alias_(agg_func(this=value.this), value.alias, quoted=False), - append=False, - copy=False, - ) - - # exists queries should not have any selects as it only checks if there are any rows - # all selects will be added by the optimizer and only used for join keys - if isinstance(parent_predicate, exp.Exists): - select.args["expressions"] = [] - - for key, alias in key_aliases.items(): - if key in group_by: - # add all keys to the projections of the subquery - # so that we can use it as a join key - if isinstance(parent_predicate, exp.Exists) or key != value.this: - select.select(f"{key} AS {alias}", copy=False) - else: - select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False) - - alias = exp.column(value.alias, table_alias) - other = _other_operand(parent_predicate) - op_type = type(parent_predicate.parent) if parent_predicate else None - - if isinstance(parent_predicate, exp.Exists): - alias = exp.column(list(key_aliases.values())[0], table_alias) - parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") - elif isinstance(parent_predicate, exp.All): - assert issubclass(op_type, exp.Binary) - predicate = op_type(this=other, expression=exp.column("_x")) - parent_predicate = _replace( - parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})" - ) - elif isinstance(parent_predicate, exp.Any): - assert issubclass(op_type, exp.Binary) - if value.this in group_by: - predicate = op_type(this=other, expression=alias) - parent_predicate = _replace(parent_predicate.parent, predicate) - else: - predicate = op_type(this=other, expression=exp.column("_x")) - parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})") - elif isinstance(parent_predicate, exp.In): - if value.this in group_by: - parent_predicate = _replace(parent_predicate, f"{other} = {alias}") - else: - parent_predicate = _replace( - parent_predicate, - f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", - ) - else: - if is_subquery_projection and select.parent.alias: - alias = exp.alias_(alias, select.parent.alias) - - # COUNT always returns 0 on empty datasets, so we need take that into consideration here - # by transforming all counts into 0 and using that as the coalesced value - if value.find(exp.Count): - - def remove_aggs(node): - if isinstance(node, exp.Count): - return exp.Literal.number(0) - elif isinstance(node, exp.AggFunc): - return exp.null() - return node - - alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)]) - - select.parent.replace(alias) - - for key, column, predicate in keys: - predicate.replace(exp.true()) - nested = exp.column(key_aliases[key], table_alias) - - if is_subquery_projection: - key.replace(nested) - if not isinstance(predicate, exp.EQ): - parent_select.where(predicate, copy=False) - continue - - if key in group_by: - key.replace(nested) - elif isinstance(predicate, exp.EQ): - parent_predicate = _replace( - parent_predicate, - f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))", - ) - else: - key.replace(exp.to_identifier("_x")) - parent_predicate = _replace( - parent_predicate, - f"({parent_predicate} AND ARRAY_ANY({nested}, _x -> {predicate}))", - ) - - parent_select.join( - select.group_by(*group_by, copy=False), - on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)], - join_type="LEFT", - 
join_alias=table_alias, - copy=False, - ) - - -def _replace(expression, condition): - return expression.replace(exp.condition(condition)) - - -def _other_operand(expression): - if isinstance(expression, exp.In): - return expression.this - - if isinstance(expression, (exp.Any, exp.All)): - return _other_operand(expression.parent) - - if isinstance(expression, exp.Binary): - return ( - expression.right - if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All)) - else expression.left - ) - - return None diff --git a/altimate_packages/sqlglot/parser.py b/altimate_packages/sqlglot/parser.py deleted file mode 100644 index 7e7038ef0..000000000 --- a/altimate_packages/sqlglot/parser.py +++ /dev/null @@ -1,8501 +0,0 @@ -from __future__ import annotations - -import logging -import typing as t -import itertools -from collections import defaultdict - -from sqlglot import exp -from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_list, seq_get -from sqlglot.time import format_time -from sqlglot.tokens import Token, Tokenizer, TokenType -from sqlglot.trie import TrieResult, in_trie, new_trie - -if t.TYPE_CHECKING: - from sqlglot._typing import E, Lit - from sqlglot.dialects.dialect import Dialect, DialectType - - T = t.TypeVar("T") - TCeilFloor = t.TypeVar("TCeilFloor", exp.Ceil, exp.Floor) - -logger = logging.getLogger("sqlglot") - -OPTIONS_TYPE = t.Dict[str, t.Sequence[t.Union[t.Sequence[str], str]]] - - -def build_var_map(args: t.List) -> exp.StarMap | exp.VarMap: - if len(args) == 1 and args[0].is_star: - return exp.StarMap(this=args[0]) - - keys = [] - values = [] - for i in range(0, len(args), 2): - keys.append(args[i]) - values.append(args[i + 1]) - - return exp.VarMap(keys=exp.array(*keys, copy=False), values=exp.array(*values, copy=False)) - - -def build_like(args: t.List) -> exp.Escape | exp.Like: - like = exp.Like(this=seq_get(args, 1), expression=seq_get(args, 0)) - return exp.Escape(this=like, expression=seq_get(args, 2)) if len(args) > 2 else like - - -def binary_range_parser( - expr_type: t.Type[exp.Expression], reverse_args: bool = False -) -> t.Callable[[Parser, t.Optional[exp.Expression]], t.Optional[exp.Expression]]: - def _parse_binary_range( - self: Parser, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - expression = self._parse_bitwise() - if reverse_args: - this, expression = expression, this - return self._parse_escape(self.expression(expr_type, this=this, expression=expression)) - - return _parse_binary_range - - -def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func: - # Default argument order is base, expression - this = seq_get(args, 0) - expression = seq_get(args, 1) - - if expression: - if not dialect.LOG_BASE_FIRST: - this, expression = expression, this - return exp.Log(this=this, expression=expression) - - return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this) - - -def build_hex(args: t.List, dialect: Dialect) -> exp.Hex | exp.LowerHex: - arg = seq_get(args, 0) - return exp.LowerHex(this=arg) if dialect.HEX_LOWERCASE else exp.Hex(this=arg) - - -def build_lower(args: t.List) -> exp.Lower | exp.Hex: - # LOWER(HEX(..)) can be simplified to LowerHex to simplify its transpilation - arg = seq_get(args, 0) - return exp.LowerHex(this=arg.this) if isinstance(arg, exp.Hex) else exp.Lower(this=arg) - - -def build_upper(args: t.List) -> exp.Upper | exp.Hex: - # UPPER(HEX(..)) can be simplified to Hex to simplify its 
transpilation - arg = seq_get(args, 0) - return exp.Hex(this=arg.this) if isinstance(arg, exp.Hex) else exp.Upper(this=arg) - - -def build_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]: - def _builder(args: t.List, dialect: Dialect) -> E: - expression = expr_type( - this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1)) - ) - if len(args) > 2 and expr_type is exp.JSONExtract: - expression.set("expressions", args[2:]) - - return expression - - return _builder - - -def build_mod(args: t.List) -> exp.Mod: - this = seq_get(args, 0) - expression = seq_get(args, 1) - - # Wrap the operands if they are binary nodes, e.g. MOD(a + 1, 7) -> (a + 1) % 7 - this = exp.Paren(this=this) if isinstance(this, exp.Binary) else this - expression = exp.Paren(this=expression) if isinstance(expression, exp.Binary) else expression - - return exp.Mod(this=this, expression=expression) - - -def build_pad(args: t.List, is_left: bool = True): - return exp.Pad( - this=seq_get(args, 0), - expression=seq_get(args, 1), - fill_pattern=seq_get(args, 2), - is_left=is_left, - ) - - -def build_array_constructor( - exp_class: t.Type[E], args: t.List, bracket_kind: TokenType, dialect: Dialect -) -> exp.Expression: - array_exp = exp_class(expressions=args) - - if exp_class == exp.Array and dialect.HAS_DISTINCT_ARRAY_CONSTRUCTORS: - array_exp.set("bracket_notation", bracket_kind == TokenType.L_BRACKET) - - return array_exp - - -def build_convert_timezone( - args: t.List, default_source_tz: t.Optional[str] = None -) -> t.Union[exp.ConvertTimezone, exp.Anonymous]: - if len(args) == 2: - source_tz = exp.Literal.string(default_source_tz) if default_source_tz else None - return exp.ConvertTimezone( - source_tz=source_tz, target_tz=seq_get(args, 0), timestamp=seq_get(args, 1) - ) - - return exp.ConvertTimezone.from_arg_list(args) - - -def build_trim(args: t.List, is_left: bool = True): - return exp.Trim( - this=seq_get(args, 0), - expression=seq_get(args, 1), - position="LEADING" if is_left else "TRAILING", - ) - - -def build_coalesce( - args: t.List, is_nvl: t.Optional[bool] = None, is_null: t.Optional[bool] = None -) -> exp.Coalesce: - return exp.Coalesce(this=seq_get(args, 0), expressions=args[1:], is_nvl=is_nvl, is_null=is_null) - - -def build_locate_strposition(args: t.List): - return exp.StrPosition( - this=seq_get(args, 1), - substr=seq_get(args, 0), - position=seq_get(args, 2), - ) - - -class _Parser(type): - def __new__(cls, clsname, bases, attrs): - klass = super().__new__(cls, clsname, bases, attrs) - - klass.SHOW_TRIE = new_trie(key.split(" ") for key in klass.SHOW_PARSERS) - klass.SET_TRIE = new_trie(key.split(" ") for key in klass.SET_PARSERS) - - return klass - - -class Parser(metaclass=_Parser): - """ - Parser consumes a list of tokens produced by the Tokenizer and produces a parsed syntax tree. - - Args: - error_level: The desired error level. - Default: ErrorLevel.IMMEDIATE - error_message_context: The amount of context to capture from a query string when displaying - the error message (in number of characters). - Default: 100 - max_errors: Maximum number of error messages to include in a raised ParseError. - This is only relevant if error_level is ErrorLevel.RAISE. 
- Default: 3 - """ - - FUNCTIONS: t.Dict[str, t.Callable] = { - **{name: func.from_arg_list for name, func in exp.FUNCTION_BY_NAME.items()}, - **dict.fromkeys(("COALESCE", "IFNULL", "NVL"), build_coalesce), - "ARRAY": lambda args, dialect: exp.Array(expressions=args), - "ARRAYAGG": lambda args, dialect: exp.ArrayAgg( - this=seq_get(args, 0), nulls_excluded=dialect.ARRAY_AGG_INCLUDES_NULLS is None or None - ), - "ARRAY_AGG": lambda args, dialect: exp.ArrayAgg( - this=seq_get(args, 0), nulls_excluded=dialect.ARRAY_AGG_INCLUDES_NULLS is None or None - ), - "CHAR": lambda args: exp.Chr(expressions=args), - "CHR": lambda args: exp.Chr(expressions=args), - "COUNT": lambda args: exp.Count(this=seq_get(args, 0), expressions=args[1:], big_int=True), - "CONCAT": lambda args, dialect: exp.Concat( - expressions=args, - safe=not dialect.STRICT_STRING_CONCAT, - coalesce=dialect.CONCAT_COALESCE, - ), - "CONCAT_WS": lambda args, dialect: exp.ConcatWs( - expressions=args, - safe=not dialect.STRICT_STRING_CONCAT, - coalesce=dialect.CONCAT_COALESCE, - ), - "CONVERT_TIMEZONE": build_convert_timezone, - "DATE_TO_DATE_STR": lambda args: exp.Cast( - this=seq_get(args, 0), - to=exp.DataType(this=exp.DataType.Type.TEXT), - ), - "GENERATE_DATE_ARRAY": lambda args: exp.GenerateDateArray( - start=seq_get(args, 0), - end=seq_get(args, 1), - step=seq_get(args, 2) or exp.Interval(this=exp.Literal.string(1), unit=exp.var("DAY")), - ), - "GLOB": lambda args: exp.Glob(this=seq_get(args, 1), expression=seq_get(args, 0)), - "HEX": build_hex, - "JSON_EXTRACT": build_extract_json_with_path(exp.JSONExtract), - "JSON_EXTRACT_SCALAR": build_extract_json_with_path(exp.JSONExtractScalar), - "JSON_EXTRACT_PATH_TEXT": build_extract_json_with_path(exp.JSONExtractScalar), - "LIKE": build_like, - "LOG": build_logarithm, - "LOG2": lambda args: exp.Log(this=exp.Literal.number(2), expression=seq_get(args, 0)), - "LOG10": lambda args: exp.Log(this=exp.Literal.number(10), expression=seq_get(args, 0)), - "LOWER": build_lower, - "LPAD": lambda args: build_pad(args), - "LEFTPAD": lambda args: build_pad(args), - "LTRIM": lambda args: build_trim(args), - "MOD": build_mod, - "RIGHTPAD": lambda args: build_pad(args, is_left=False), - "RPAD": lambda args: build_pad(args, is_left=False), - "RTRIM": lambda args: build_trim(args, is_left=False), - "SCOPE_RESOLUTION": lambda args: exp.ScopeResolution(expression=seq_get(args, 0)) - if len(args) != 2 - else exp.ScopeResolution(this=seq_get(args, 0), expression=seq_get(args, 1)), - "STRPOS": exp.StrPosition.from_arg_list, - "CHARINDEX": lambda args: build_locate_strposition(args), - "INSTR": exp.StrPosition.from_arg_list, - "LOCATE": lambda args: build_locate_strposition(args), - "TIME_TO_TIME_STR": lambda args: exp.Cast( - this=seq_get(args, 0), - to=exp.DataType(this=exp.DataType.Type.TEXT), - ), - "TO_HEX": build_hex, - "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring( - this=exp.Cast( - this=seq_get(args, 0), - to=exp.DataType(this=exp.DataType.Type.TEXT), - ), - start=exp.Literal.number(1), - length=exp.Literal.number(10), - ), - "UNNEST": lambda args: exp.Unnest(expressions=ensure_list(seq_get(args, 0))), - "UPPER": build_upper, - "VAR_MAP": build_var_map, - } - - NO_PAREN_FUNCTIONS = { - TokenType.CURRENT_DATE: exp.CurrentDate, - TokenType.CURRENT_DATETIME: exp.CurrentDate, - TokenType.CURRENT_TIME: exp.CurrentTime, - TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp, - TokenType.CURRENT_USER: exp.CurrentUser, - } - - STRUCT_TYPE_TOKENS = { - TokenType.NESTED, - TokenType.OBJECT, - 
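These builder lambdas normalize dialect-specific spellings onto shared AST nodes before generation. A small sketch of two of them, `build_coalesce` (which `COALESCE`, `IFNULL`, and `NVL` all share) and `build_mod` (which parenthesizes binary operands):

```python
import sqlglot

# NVL and IFNULL are registered alongside COALESCE, so they all parse to exp.Coalesce
print(repr(sqlglot.parse_one("NVL(a, b)")))

# build_mod wraps binary operands, so MOD(a + 1, 7) round-trips as (a + 1) % 7
print(sqlglot.parse_one("MOD(a + 1, 7)").sql())
```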
TokenType.STRUCT, - TokenType.UNION, - } - - NESTED_TYPE_TOKENS = { - TokenType.ARRAY, - TokenType.LIST, - TokenType.LOWCARDINALITY, - TokenType.MAP, - TokenType.NULLABLE, - TokenType.RANGE, - *STRUCT_TYPE_TOKENS, - } - - ENUM_TYPE_TOKENS = { - TokenType.DYNAMIC, - TokenType.ENUM, - TokenType.ENUM8, - TokenType.ENUM16, - } - - AGGREGATE_TYPE_TOKENS = { - TokenType.AGGREGATEFUNCTION, - TokenType.SIMPLEAGGREGATEFUNCTION, - } - - TYPE_TOKENS = { - TokenType.BIT, - TokenType.BOOLEAN, - TokenType.TINYINT, - TokenType.UTINYINT, - TokenType.SMALLINT, - TokenType.USMALLINT, - TokenType.INT, - TokenType.UINT, - TokenType.BIGINT, - TokenType.UBIGINT, - TokenType.INT128, - TokenType.UINT128, - TokenType.INT256, - TokenType.UINT256, - TokenType.MEDIUMINT, - TokenType.UMEDIUMINT, - TokenType.FIXEDSTRING, - TokenType.FLOAT, - TokenType.DOUBLE, - TokenType.UDOUBLE, - TokenType.CHAR, - TokenType.NCHAR, - TokenType.VARCHAR, - TokenType.NVARCHAR, - TokenType.BPCHAR, - TokenType.TEXT, - TokenType.MEDIUMTEXT, - TokenType.LONGTEXT, - TokenType.BLOB, - TokenType.MEDIUMBLOB, - TokenType.LONGBLOB, - TokenType.BINARY, - TokenType.VARBINARY, - TokenType.JSON, - TokenType.JSONB, - TokenType.INTERVAL, - TokenType.TINYBLOB, - TokenType.TINYTEXT, - TokenType.TIME, - TokenType.TIMETZ, - TokenType.TIMESTAMP, - TokenType.TIMESTAMP_S, - TokenType.TIMESTAMP_MS, - TokenType.TIMESTAMP_NS, - TokenType.TIMESTAMPTZ, - TokenType.TIMESTAMPLTZ, - TokenType.TIMESTAMPNTZ, - TokenType.DATETIME, - TokenType.DATETIME2, - TokenType.DATETIME64, - TokenType.SMALLDATETIME, - TokenType.DATE, - TokenType.DATE32, - TokenType.INT4RANGE, - TokenType.INT4MULTIRANGE, - TokenType.INT8RANGE, - TokenType.INT8MULTIRANGE, - TokenType.NUMRANGE, - TokenType.NUMMULTIRANGE, - TokenType.TSRANGE, - TokenType.TSMULTIRANGE, - TokenType.TSTZRANGE, - TokenType.TSTZMULTIRANGE, - TokenType.DATERANGE, - TokenType.DATEMULTIRANGE, - TokenType.DECIMAL, - TokenType.DECIMAL32, - TokenType.DECIMAL64, - TokenType.DECIMAL128, - TokenType.DECIMAL256, - TokenType.UDECIMAL, - TokenType.BIGDECIMAL, - TokenType.UUID, - TokenType.GEOGRAPHY, - TokenType.GEOMETRY, - TokenType.POINT, - TokenType.RING, - TokenType.LINESTRING, - TokenType.MULTILINESTRING, - TokenType.POLYGON, - TokenType.MULTIPOLYGON, - TokenType.HLLSKETCH, - TokenType.HSTORE, - TokenType.PSEUDO_TYPE, - TokenType.SUPER, - TokenType.SERIAL, - TokenType.SMALLSERIAL, - TokenType.BIGSERIAL, - TokenType.XML, - TokenType.YEAR, - TokenType.USERDEFINED, - TokenType.MONEY, - TokenType.SMALLMONEY, - TokenType.ROWVERSION, - TokenType.IMAGE, - TokenType.VARIANT, - TokenType.VECTOR, - TokenType.VOID, - TokenType.OBJECT, - TokenType.OBJECT_IDENTIFIER, - TokenType.INET, - TokenType.IPADDRESS, - TokenType.IPPREFIX, - TokenType.IPV4, - TokenType.IPV6, - TokenType.UNKNOWN, - TokenType.NOTHING, - TokenType.NULL, - TokenType.NAME, - TokenType.TDIGEST, - TokenType.DYNAMIC, - *ENUM_TYPE_TOKENS, - *NESTED_TYPE_TOKENS, - *AGGREGATE_TYPE_TOKENS, - } - - SIGNED_TO_UNSIGNED_TYPE_TOKEN = { - TokenType.BIGINT: TokenType.UBIGINT, - TokenType.INT: TokenType.UINT, - TokenType.MEDIUMINT: TokenType.UMEDIUMINT, - TokenType.SMALLINT: TokenType.USMALLINT, - TokenType.TINYINT: TokenType.UTINYINT, - TokenType.DECIMAL: TokenType.UDECIMAL, - TokenType.DOUBLE: TokenType.UDOUBLE, - } - - SUBQUERY_PREDICATES = { - TokenType.ANY: exp.Any, - TokenType.ALL: exp.All, - TokenType.EXISTS: exp.Exists, - TokenType.SOME: exp.Any, - } - - RESERVED_TOKENS = { - *Tokenizer.SINGLE_TOKENS.values(), - TokenType.SELECT, - } - {TokenType.IDENTIFIER} - - DB_CREATABLES = { - 
TokenType.DATABASE, - TokenType.DICTIONARY, - TokenType.FILE_FORMAT, - TokenType.MODEL, - TokenType.NAMESPACE, - TokenType.SCHEMA, - TokenType.SEQUENCE, - TokenType.SINK, - TokenType.SOURCE, - TokenType.STAGE, - TokenType.STORAGE_INTEGRATION, - TokenType.STREAMLIT, - TokenType.TABLE, - TokenType.TAG, - TokenType.VIEW, - TokenType.WAREHOUSE, - } - - CREATABLES = { - TokenType.COLUMN, - TokenType.CONSTRAINT, - TokenType.FOREIGN_KEY, - TokenType.FUNCTION, - TokenType.INDEX, - TokenType.PROCEDURE, - *DB_CREATABLES, - } - - ALTERABLES = { - TokenType.INDEX, - TokenType.TABLE, - TokenType.VIEW, - } - - # Tokens that can represent identifiers - ID_VAR_TOKENS = { - TokenType.ALL, - TokenType.ATTACH, - TokenType.VAR, - TokenType.ANTI, - TokenType.APPLY, - TokenType.ASC, - TokenType.ASOF, - TokenType.AUTO_INCREMENT, - TokenType.BEGIN, - TokenType.BPCHAR, - TokenType.CACHE, - TokenType.CASE, - TokenType.COLLATE, - TokenType.COMMAND, - TokenType.COMMENT, - TokenType.COMMIT, - TokenType.CONSTRAINT, - TokenType.COPY, - TokenType.CUBE, - TokenType.CURRENT_SCHEMA, - TokenType.DEFAULT, - TokenType.DELETE, - TokenType.DESC, - TokenType.DESCRIBE, - TokenType.DETACH, - TokenType.DICTIONARY, - TokenType.DIV, - TokenType.END, - TokenType.EXECUTE, - TokenType.EXPORT, - TokenType.ESCAPE, - TokenType.FALSE, - TokenType.FIRST, - TokenType.FILTER, - TokenType.FINAL, - TokenType.FORMAT, - TokenType.FULL, - TokenType.GET, - TokenType.IDENTIFIER, - TokenType.IS, - TokenType.ISNULL, - TokenType.INTERVAL, - TokenType.KEEP, - TokenType.KILL, - TokenType.LEFT, - TokenType.LIMIT, - TokenType.LOAD, - TokenType.MERGE, - TokenType.NATURAL, - TokenType.NEXT, - TokenType.OFFSET, - TokenType.OPERATOR, - TokenType.ORDINALITY, - TokenType.OVERLAPS, - TokenType.OVERWRITE, - TokenType.PARTITION, - TokenType.PERCENT, - TokenType.PIVOT, - TokenType.PRAGMA, - TokenType.PUT, - TokenType.RANGE, - TokenType.RECURSIVE, - TokenType.REFERENCES, - TokenType.REFRESH, - TokenType.RENAME, - TokenType.REPLACE, - TokenType.RIGHT, - TokenType.ROLLUP, - TokenType.ROW, - TokenType.ROWS, - TokenType.SEMI, - TokenType.SET, - TokenType.SETTINGS, - TokenType.SHOW, - TokenType.TEMPORARY, - TokenType.TOP, - TokenType.TRUE, - TokenType.TRUNCATE, - TokenType.UNIQUE, - TokenType.UNNEST, - TokenType.UNPIVOT, - TokenType.UPDATE, - TokenType.USE, - TokenType.VOLATILE, - TokenType.WINDOW, - *CREATABLES, - *SUBQUERY_PREDICATES, - *TYPE_TOKENS, - *NO_PAREN_FUNCTIONS, - } - ID_VAR_TOKENS.remove(TokenType.UNION) - - TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { - TokenType.ANTI, - TokenType.APPLY, - TokenType.ASOF, - TokenType.FULL, - TokenType.LEFT, - TokenType.LOCK, - TokenType.NATURAL, - TokenType.RIGHT, - TokenType.SEMI, - TokenType.WINDOW, - } - - ALIAS_TOKENS = ID_VAR_TOKENS - - COLON_PLACEHOLDER_TOKENS = ID_VAR_TOKENS - - ARRAY_CONSTRUCTORS = { - "ARRAY": exp.Array, - "LIST": exp.List, - } - - COMMENT_TABLE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.IS} - - UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} - - TRIM_TYPES = {"LEADING", "TRAILING", "BOTH"} - - FUNC_TOKENS = { - TokenType.COLLATE, - TokenType.COMMAND, - TokenType.CURRENT_DATE, - TokenType.CURRENT_DATETIME, - TokenType.CURRENT_SCHEMA, - TokenType.CURRENT_TIMESTAMP, - TokenType.CURRENT_TIME, - TokenType.CURRENT_USER, - TokenType.FILTER, - TokenType.FIRST, - TokenType.FORMAT, - TokenType.GET, - TokenType.GLOB, - TokenType.IDENTIFIER, - TokenType.INDEX, - TokenType.ISNULL, - TokenType.ILIKE, - TokenType.INSERT, - TokenType.LIKE, - TokenType.MERGE, - TokenType.NEXT, - TokenType.OFFSET, - 
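Because so many keyword tokens are whitelisted in `ID_VAR_TOKENS`, they remain usable as plain column names. A tiny sketch (assumption: these particular keywords parse as identifiers in the default dialect):

```python
import sqlglot

# FORMAT, FIRST and NEXT all appear in ID_VAR_TOKENS above, so they are accepted
# as ordinary column identifiers rather than rejected as reserved words.
print(sqlglot.parse_one("SELECT format, first, next FROM t").sql())
```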
TokenType.PRIMARY_KEY, - TokenType.RANGE, - TokenType.REPLACE, - TokenType.RLIKE, - TokenType.ROW, - TokenType.UNNEST, - TokenType.VAR, - TokenType.LEFT, - TokenType.RIGHT, - TokenType.SEQUENCE, - TokenType.DATE, - TokenType.DATETIME, - TokenType.TABLE, - TokenType.TIMESTAMP, - TokenType.TIMESTAMPTZ, - TokenType.TRUNCATE, - TokenType.WINDOW, - TokenType.XOR, - *TYPE_TOKENS, - *SUBQUERY_PREDICATES, - } - - CONJUNCTION: t.Dict[TokenType, t.Type[exp.Expression]] = { - TokenType.AND: exp.And, - } - - ASSIGNMENT: t.Dict[TokenType, t.Type[exp.Expression]] = { - TokenType.COLON_EQ: exp.PropertyEQ, - } - - DISJUNCTION: t.Dict[TokenType, t.Type[exp.Expression]] = { - TokenType.OR: exp.Or, - } - - EQUALITY = { - TokenType.EQ: exp.EQ, - TokenType.NEQ: exp.NEQ, - TokenType.NULLSAFE_EQ: exp.NullSafeEQ, - } - - COMPARISON = { - TokenType.GT: exp.GT, - TokenType.GTE: exp.GTE, - TokenType.LT: exp.LT, - TokenType.LTE: exp.LTE, - } - - BITWISE = { - TokenType.AMP: exp.BitwiseAnd, - TokenType.CARET: exp.BitwiseXor, - TokenType.PIPE: exp.BitwiseOr, - } - - TERM = { - TokenType.DASH: exp.Sub, - TokenType.PLUS: exp.Add, - TokenType.MOD: exp.Mod, - TokenType.COLLATE: exp.Collate, - } - - FACTOR = { - TokenType.DIV: exp.IntDiv, - TokenType.LR_ARROW: exp.Distance, - TokenType.SLASH: exp.Div, - TokenType.STAR: exp.Mul, - } - - EXPONENT: t.Dict[TokenType, t.Type[exp.Expression]] = {} - - TIMES = { - TokenType.TIME, - TokenType.TIMETZ, - } - - TIMESTAMPS = { - TokenType.TIMESTAMP, - TokenType.TIMESTAMPNTZ, - TokenType.TIMESTAMPTZ, - TokenType.TIMESTAMPLTZ, - *TIMES, - } - - SET_OPERATIONS = { - TokenType.UNION, - TokenType.INTERSECT, - TokenType.EXCEPT, - } - - JOIN_METHODS = { - TokenType.ASOF, - TokenType.NATURAL, - TokenType.POSITIONAL, - } - - JOIN_SIDES = { - TokenType.LEFT, - TokenType.RIGHT, - TokenType.FULL, - } - - JOIN_KINDS = { - TokenType.ANTI, - TokenType.CROSS, - TokenType.INNER, - TokenType.OUTER, - TokenType.SEMI, - TokenType.STRAIGHT_JOIN, - } - - JOIN_HINTS: t.Set[str] = set() - - LAMBDAS = { - TokenType.ARROW: lambda self, expressions: self.expression( - exp.Lambda, - this=self._replace_lambda( - self._parse_assignment(), - expressions, - ), - expressions=expressions, - ), - TokenType.FARROW: lambda self, expressions: self.expression( - exp.Kwarg, - this=exp.var(expressions[0].name), - expression=self._parse_assignment(), - ), - } - - COLUMN_OPERATORS = { - TokenType.DOT: None, - TokenType.DOTCOLON: lambda self, this, to: self.expression( - exp.JSONCast, - this=this, - to=to, - ), - TokenType.DCOLON: lambda self, this, to: self.expression( - exp.Cast if self.STRICT_CAST else exp.TryCast, - this=this, - to=to, - ), - TokenType.ARROW: lambda self, this, path: self.expression( - exp.JSONExtract, - this=this, - expression=self.dialect.to_json_path(path), - only_json_types=self.JSON_ARROWS_REQUIRE_JSON_TYPE, - ), - TokenType.DARROW: lambda self, this, path: self.expression( - exp.JSONExtractScalar, - this=this, - expression=self.dialect.to_json_path(path), - only_json_types=self.JSON_ARROWS_REQUIRE_JSON_TYPE, - ), - TokenType.HASH_ARROW: lambda self, this, path: self.expression( - exp.JSONBExtract, - this=this, - expression=path, - ), - TokenType.DHASH_ARROW: lambda self, this, path: self.expression( - exp.JSONBExtractScalar, - this=this, - expression=path, - ), - TokenType.PLACEHOLDER: lambda self, this, key: self.expression( - exp.JSONBContains, - this=this, - expression=key, - ), - } - - EXPRESSION_PARSERS = { - exp.Cluster: lambda self: self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY), - 
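`COLUMN_OPERATORS` is where infix/postfix column syntax such as `::` and the JSON arrows gets mapped onto expression nodes. A minimal sketch:

```python
import sqlglot

# '::' is handled by COLUMN_OPERATORS[TokenType.DCOLON] and yields a Cast
# (or TryCast for dialects where STRICT_CAST is disabled)
print(repr(sqlglot.parse_one("x::INT")))

# '->' goes through the ARROW handler and becomes a JSONExtract with a dialect-parsed path
print(repr(sqlglot.parse_one("payload -> 'user'", read="postgres")))
```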
exp.Column: lambda self: self._parse_column(), - exp.Condition: lambda self: self._parse_assignment(), - exp.DataType: lambda self: self._parse_types(allow_identifiers=False, schema=True), - exp.Expression: lambda self: self._parse_expression(), - exp.From: lambda self: self._parse_from(joins=True), - exp.Group: lambda self: self._parse_group(), - exp.Having: lambda self: self._parse_having(), - exp.Hint: lambda self: self._parse_hint_body(), - exp.Identifier: lambda self: self._parse_id_var(), - exp.Join: lambda self: self._parse_join(), - exp.Lambda: lambda self: self._parse_lambda(), - exp.Lateral: lambda self: self._parse_lateral(), - exp.Limit: lambda self: self._parse_limit(), - exp.Offset: lambda self: self._parse_offset(), - exp.Order: lambda self: self._parse_order(), - exp.Ordered: lambda self: self._parse_ordered(), - exp.Properties: lambda self: self._parse_properties(), - exp.PartitionedByProperty: lambda self: self._parse_partitioned_by(), - exp.Qualify: lambda self: self._parse_qualify(), - exp.Returning: lambda self: self._parse_returning(), - exp.Select: lambda self: self._parse_select(), - exp.Sort: lambda self: self._parse_sort(exp.Sort, TokenType.SORT_BY), - exp.Table: lambda self: self._parse_table_parts(), - exp.TableAlias: lambda self: self._parse_table_alias(), - exp.Tuple: lambda self: self._parse_value(values=False), - exp.Whens: lambda self: self._parse_when_matched(), - exp.Where: lambda self: self._parse_where(), - exp.Window: lambda self: self._parse_named_window(), - exp.With: lambda self: self._parse_with(), - "JOIN_TYPE": lambda self: self._parse_join_parts(), - } - - STATEMENT_PARSERS = { - TokenType.ALTER: lambda self: self._parse_alter(), - TokenType.ANALYZE: lambda self: self._parse_analyze(), - TokenType.BEGIN: lambda self: self._parse_transaction(), - TokenType.CACHE: lambda self: self._parse_cache(), - TokenType.COMMENT: lambda self: self._parse_comment(), - TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), - TokenType.COPY: lambda self: self._parse_copy(), - TokenType.CREATE: lambda self: self._parse_create(), - TokenType.DELETE: lambda self: self._parse_delete(), - TokenType.DESC: lambda self: self._parse_describe(), - TokenType.DESCRIBE: lambda self: self._parse_describe(), - TokenType.DROP: lambda self: self._parse_drop(), - TokenType.GRANT: lambda self: self._parse_grant(), - TokenType.INSERT: lambda self: self._parse_insert(), - TokenType.KILL: lambda self: self._parse_kill(), - TokenType.LOAD: lambda self: self._parse_load(), - TokenType.MERGE: lambda self: self._parse_merge(), - TokenType.PIVOT: lambda self: self._parse_simplified_pivot(), - TokenType.PRAGMA: lambda self: self.expression(exp.Pragma, this=self._parse_expression()), - TokenType.REFRESH: lambda self: self._parse_refresh(), - TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), - TokenType.SET: lambda self: self._parse_set(), - TokenType.TRUNCATE: lambda self: self._parse_truncate_table(), - TokenType.UNCACHE: lambda self: self._parse_uncache(), - TokenType.UNPIVOT: lambda self: self._parse_simplified_pivot(is_unpivot=True), - TokenType.UPDATE: lambda self: self._parse_update(), - TokenType.USE: lambda self: self._parse_use(), - TokenType.SEMICOLON: lambda self: exp.Semicolon(), - } - - UNARY_PARSERS = { - TokenType.PLUS: lambda self: self._parse_unary(), # Unary + is handled as a no-op - TokenType.NOT: lambda self: self.expression(exp.Not, this=self._parse_equality()), - TokenType.TILDA: lambda self: self.expression(exp.BitwiseNot, 
this=self._parse_unary()), - TokenType.DASH: lambda self: self.expression(exp.Neg, this=self._parse_unary()), - TokenType.PIPE_SLASH: lambda self: self.expression(exp.Sqrt, this=self._parse_unary()), - TokenType.DPIPE_SLASH: lambda self: self.expression(exp.Cbrt, this=self._parse_unary()), - } - - STRING_PARSERS = { - TokenType.HEREDOC_STRING: lambda self, token: self.expression( - exp.RawString, this=token.text - ), - TokenType.NATIONAL_STRING: lambda self, token: self.expression( - exp.National, this=token.text - ), - TokenType.RAW_STRING: lambda self, token: self.expression(exp.RawString, this=token.text), - TokenType.STRING: lambda self, token: self.expression( - exp.Literal, this=token.text, is_string=True - ), - TokenType.UNICODE_STRING: lambda self, token: self.expression( - exp.UnicodeString, - this=token.text, - escape=self._match_text_seq("UESCAPE") and self._parse_string(), - ), - } - - NUMERIC_PARSERS = { - TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text), - TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), - TokenType.HEX_STRING: lambda self, token: self.expression( - exp.HexString, - this=token.text, - is_integer=self.dialect.HEX_STRING_IS_INTEGER_TYPE or None, - ), - TokenType.NUMBER: lambda self, token: self.expression( - exp.Literal, this=token.text, is_string=False - ), - } - - PRIMARY_PARSERS = { - **STRING_PARSERS, - **NUMERIC_PARSERS, - TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), - TokenType.NULL: lambda self, _: self.expression(exp.Null), - TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True), - TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False), - TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), - TokenType.STAR: lambda self, _: self._parse_star_ops(), - } - - PLACEHOLDER_PARSERS = { - TokenType.PLACEHOLDER: lambda self: self.expression(exp.Placeholder), - TokenType.PARAMETER: lambda self: self._parse_parameter(), - TokenType.COLON: lambda self: ( - self.expression(exp.Placeholder, this=self._prev.text) - if self._match_set(self.COLON_PLACEHOLDER_TOKENS) - else None - ), - } - - RANGE_PARSERS = { - TokenType.AT_GT: binary_range_parser(exp.ArrayContainsAll), - TokenType.BETWEEN: lambda self, this: self._parse_between(this), - TokenType.GLOB: binary_range_parser(exp.Glob), - TokenType.ILIKE: binary_range_parser(exp.ILike), - TokenType.IN: lambda self, this: self._parse_in(this), - TokenType.IRLIKE: binary_range_parser(exp.RegexpILike), - TokenType.IS: lambda self, this: self._parse_is(this), - TokenType.LIKE: binary_range_parser(exp.Like), - TokenType.LT_AT: binary_range_parser(exp.ArrayContainsAll, reverse_args=True), - TokenType.OVERLAPS: binary_range_parser(exp.Overlaps), - TokenType.RLIKE: binary_range_parser(exp.RegexpLike), - TokenType.SIMILAR_TO: binary_range_parser(exp.SimilarTo), - TokenType.FOR: lambda self, this: self._parse_comprehension(this), - } - - PIPE_SYNTAX_TRANSFORM_PARSERS = { - "AGGREGATE": lambda self, query: self._parse_pipe_syntax_aggregate(query), - "AS": lambda self, query: self._build_pipe_cte( - query, [exp.Star()], self._parse_table_alias() - ), - "EXTEND": lambda self, query: self._parse_pipe_syntax_extend(query), - "LIMIT": lambda self, query: self._parse_pipe_syntax_limit(query), - "ORDER BY": lambda self, query: query.order_by( - self._parse_order(), append=False, copy=False - ), - "PIVOT": lambda self, query: self._parse_pipe_syntax_pivot(query), - 
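`PLACEHOLDER_PARSERS` above is what lets bind-style parameters survive parsing as `exp.Placeholder` nodes. A minimal sketch:

```python
import sqlglot
from sqlglot import exp

# ':id' is matched by the TokenType.COLON entry and becomes a named Placeholder;
# a bare '?' goes through TokenType.PLACEHOLDER and becomes an anonymous one.
tree = sqlglot.parse_one("SELECT * FROM t WHERE id = :id AND status = ?")
print([p.sql() for p in tree.find_all(exp.Placeholder)])
```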
"SELECT": lambda self, query: self._parse_pipe_syntax_select(query), - "UNPIVOT": lambda self, query: self._parse_pipe_syntax_pivot(query), - "WHERE": lambda self, query: query.where(self._parse_where(), copy=False), - } - - PROPERTY_PARSERS: t.Dict[str, t.Callable] = { - "ALLOWED_VALUES": lambda self: self.expression( - exp.AllowedValuesProperty, expressions=self._parse_csv(self._parse_primary) - ), - "ALGORITHM": lambda self: self._parse_property_assignment(exp.AlgorithmProperty), - "AUTO": lambda self: self._parse_auto_property(), - "AUTO_INCREMENT": lambda self: self._parse_property_assignment(exp.AutoIncrementProperty), - "BACKUP": lambda self: self.expression( - exp.BackupProperty, this=self._parse_var(any_token=True) - ), - "BLOCKCOMPRESSION": lambda self: self._parse_blockcompression(), - "CHARSET": lambda self, **kwargs: self._parse_character_set(**kwargs), - "CHARACTER SET": lambda self, **kwargs: self._parse_character_set(**kwargs), - "CHECKSUM": lambda self: self._parse_checksum(), - "CLUSTER BY": lambda self: self._parse_cluster(), - "CLUSTERED": lambda self: self._parse_clustered_by(), - "COLLATE": lambda self, **kwargs: self._parse_property_assignment( - exp.CollateProperty, **kwargs - ), - "COMMENT": lambda self: self._parse_property_assignment(exp.SchemaCommentProperty), - "CONTAINS": lambda self: self._parse_contains_property(), - "COPY": lambda self: self._parse_copy_property(), - "DATABLOCKSIZE": lambda self, **kwargs: self._parse_datablocksize(**kwargs), - "DATA_DELETION": lambda self: self._parse_data_deletion_property(), - "DEFINER": lambda self: self._parse_definer(), - "DETERMINISTIC": lambda self: self.expression( - exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") - ), - "DISTRIBUTED": lambda self: self._parse_distributed_property(), - "DUPLICATE": lambda self: self._parse_composite_key_property(exp.DuplicateKeyProperty), - "DYNAMIC": lambda self: self.expression(exp.DynamicProperty), - "DISTKEY": lambda self: self._parse_distkey(), - "DISTSTYLE": lambda self: self._parse_property_assignment(exp.DistStyleProperty), - "EMPTY": lambda self: self.expression(exp.EmptyProperty), - "ENGINE": lambda self: self._parse_property_assignment(exp.EngineProperty), - "ENVIRONMENT": lambda self: self.expression( - exp.EnviromentProperty, expressions=self._parse_wrapped_csv(self._parse_assignment) - ), - "EXECUTE": lambda self: self._parse_property_assignment(exp.ExecuteAsProperty), - "EXTERNAL": lambda self: self.expression(exp.ExternalProperty), - "FALLBACK": lambda self, **kwargs: self._parse_fallback(**kwargs), - "FORMAT": lambda self: self._parse_property_assignment(exp.FileFormatProperty), - "FREESPACE": lambda self: self._parse_freespace(), - "GLOBAL": lambda self: self.expression(exp.GlobalProperty), - "HEAP": lambda self: self.expression(exp.HeapProperty), - "ICEBERG": lambda self: self.expression(exp.IcebergProperty), - "IMMUTABLE": lambda self: self.expression( - exp.StabilityProperty, this=exp.Literal.string("IMMUTABLE") - ), - "INHERITS": lambda self: self.expression( - exp.InheritsProperty, expressions=self._parse_wrapped_csv(self._parse_table) - ), - "INPUT": lambda self: self.expression(exp.InputModelProperty, this=self._parse_schema()), - "JOURNAL": lambda self, **kwargs: self._parse_journal(**kwargs), - "LANGUAGE": lambda self: self._parse_property_assignment(exp.LanguageProperty), - "LAYOUT": lambda self: self._parse_dict_property(this="LAYOUT"), - "LIFETIME": lambda self: self._parse_dict_range(this="LIFETIME"), - "LIKE": lambda self: 
self._parse_create_like(), - "LOCATION": lambda self: self._parse_property_assignment(exp.LocationProperty), - "LOCK": lambda self: self._parse_locking(), - "LOCKING": lambda self: self._parse_locking(), - "LOG": lambda self, **kwargs: self._parse_log(**kwargs), - "MATERIALIZED": lambda self: self.expression(exp.MaterializedProperty), - "MERGEBLOCKRATIO": lambda self, **kwargs: self._parse_mergeblockratio(**kwargs), - "MODIFIES": lambda self: self._parse_modifies_property(), - "MULTISET": lambda self: self.expression(exp.SetProperty, multi=True), - "NO": lambda self: self._parse_no_property(), - "ON": lambda self: self._parse_on_property(), - "ORDER BY": lambda self: self._parse_order(skip_order_token=True), - "OUTPUT": lambda self: self.expression(exp.OutputModelProperty, this=self._parse_schema()), - "PARTITION": lambda self: self._parse_partitioned_of(), - "PARTITION BY": lambda self: self._parse_partitioned_by(), - "PARTITIONED BY": lambda self: self._parse_partitioned_by(), - "PARTITIONED_BY": lambda self: self._parse_partitioned_by(), - "PRIMARY KEY": lambda self: self._parse_primary_key(in_props=True), - "RANGE": lambda self: self._parse_dict_range(this="RANGE"), - "READS": lambda self: self._parse_reads_property(), - "REMOTE": lambda self: self._parse_remote_with_connection(), - "RETURNS": lambda self: self._parse_returns(), - "STRICT": lambda self: self.expression(exp.StrictProperty), - "STREAMING": lambda self: self.expression(exp.StreamingTableProperty), - "ROW": lambda self: self._parse_row(), - "ROW_FORMAT": lambda self: self._parse_property_assignment(exp.RowFormatProperty), - "SAMPLE": lambda self: self.expression( - exp.SampleProperty, this=self._match_text_seq("BY") and self._parse_bitwise() - ), - "SECURE": lambda self: self.expression(exp.SecureProperty), - "SECURITY": lambda self: self._parse_security(), - "SET": lambda self: self.expression(exp.SetProperty, multi=False), - "SETTINGS": lambda self: self._parse_settings_property(), - "SHARING": lambda self: self._parse_property_assignment(exp.SharingProperty), - "SORTKEY": lambda self: self._parse_sortkey(), - "SOURCE": lambda self: self._parse_dict_property(this="SOURCE"), - "STABLE": lambda self: self.expression( - exp.StabilityProperty, this=exp.Literal.string("STABLE") - ), - "STORED": lambda self: self._parse_stored(), - "SYSTEM_VERSIONING": lambda self: self._parse_system_versioning_property(), - "TBLPROPERTIES": lambda self: self._parse_wrapped_properties(), - "TEMP": lambda self: self.expression(exp.TemporaryProperty), - "TEMPORARY": lambda self: self.expression(exp.TemporaryProperty), - "TO": lambda self: self._parse_to_table(), - "TRANSIENT": lambda self: self.expression(exp.TransientProperty), - "TRANSFORM": lambda self: self.expression( - exp.TransformModelProperty, expressions=self._parse_wrapped_csv(self._parse_expression) - ), - "TTL": lambda self: self._parse_ttl(), - "USING": lambda self: self._parse_property_assignment(exp.FileFormatProperty), - "UNLOGGED": lambda self: self.expression(exp.UnloggedProperty), - "VOLATILE": lambda self: self._parse_volatile_property(), - "WITH": lambda self: self._parse_with_property(), - } - - CONSTRAINT_PARSERS = { - "AUTOINCREMENT": lambda self: self._parse_auto_increment(), - "AUTO_INCREMENT": lambda self: self._parse_auto_increment(), - "CASESPECIFIC": lambda self: self.expression(exp.CaseSpecificColumnConstraint, not_=False), - "CHARACTER SET": lambda self: self.expression( - exp.CharacterSetColumnConstraint, this=self._parse_var_or_string() - ), - "CHECK": lambda 
self: self.expression( - exp.CheckColumnConstraint, - this=self._parse_wrapped(self._parse_assignment), - enforced=self._match_text_seq("ENFORCED"), - ), - "COLLATE": lambda self: self.expression( - exp.CollateColumnConstraint, - this=self._parse_identifier() or self._parse_column(), - ), - "COMMENT": lambda self: self.expression( - exp.CommentColumnConstraint, this=self._parse_string() - ), - "COMPRESS": lambda self: self._parse_compress(), - "CLUSTERED": lambda self: self.expression( - exp.ClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered) - ), - "NONCLUSTERED": lambda self: self.expression( - exp.NonClusteredColumnConstraint, this=self._parse_wrapped_csv(self._parse_ordered) - ), - "DEFAULT": lambda self: self.expression( - exp.DefaultColumnConstraint, this=self._parse_bitwise() - ), - "ENCODE": lambda self: self.expression(exp.EncodeColumnConstraint, this=self._parse_var()), - "EPHEMERAL": lambda self: self.expression( - exp.EphemeralColumnConstraint, this=self._parse_bitwise() - ), - "EXCLUDE": lambda self: self.expression( - exp.ExcludeColumnConstraint, this=self._parse_index_params() - ), - "FOREIGN KEY": lambda self: self._parse_foreign_key(), - "FORMAT": lambda self: self.expression( - exp.DateFormatColumnConstraint, this=self._parse_var_or_string() - ), - "GENERATED": lambda self: self._parse_generated_as_identity(), - "IDENTITY": lambda self: self._parse_auto_increment(), - "INLINE": lambda self: self._parse_inline(), - "LIKE": lambda self: self._parse_create_like(), - "NOT": lambda self: self._parse_not_constraint(), - "NULL": lambda self: self.expression(exp.NotNullColumnConstraint, allow_null=True), - "ON": lambda self: ( - self._match(TokenType.UPDATE) - and self.expression(exp.OnUpdateColumnConstraint, this=self._parse_function()) - ) - or self.expression(exp.OnProperty, this=self._parse_id_var()), - "PATH": lambda self: self.expression(exp.PathColumnConstraint, this=self._parse_string()), - "PERIOD": lambda self: self._parse_period_for_system_time(), - "PRIMARY KEY": lambda self: self._parse_primary_key(), - "REFERENCES": lambda self: self._parse_references(match=False), - "TITLE": lambda self: self.expression( - exp.TitleColumnConstraint, this=self._parse_var_or_string() - ), - "TTL": lambda self: self.expression(exp.MergeTreeTTL, expressions=[self._parse_bitwise()]), - "UNIQUE": lambda self: self._parse_unique(), - "UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint), - "WATERMARK": lambda self: self.expression( - exp.WatermarkColumnConstraint, - this=self._match(TokenType.FOR) and self._parse_column(), - expression=self._match(TokenType.ALIAS) and self._parse_disjunction(), - ), - "WITH": lambda self: self.expression( - exp.Properties, expressions=self._parse_wrapped_properties() - ), - "BUCKET": lambda self: self._parse_partitioned_by_bucket_or_truncate(), - "TRUNCATE": lambda self: self._parse_partitioned_by_bucket_or_truncate(), - } - - def _parse_partitioned_by_bucket_or_truncate(self) -> exp.Expression: - klass = ( - exp.PartitionedByBucket - if self._prev.text.upper() == "BUCKET" - else exp.PartitionByTruncate - ) - - args = self._parse_wrapped_csv(lambda: self._parse_primary() or self._parse_column()) - this, expression = seq_get(args, 0), seq_get(args, 1) - - if isinstance(this, exp.Literal): - # Check for Iceberg partition transforms (bucket / truncate) and ensure their arguments are in the right order - # - For Hive, it's `bucket(, )` or `truncate(, )` - # - For Trino, it's reversed - `bucket(, )` or `truncate(, )` 
- # Both variants are canonicalized in the latter i.e `bucket(, )` - # - # Hive ref: https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html#querying-iceberg-partitioning - # Trino ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties - this, expression = expression, this - - return self.expression(klass, this=this, expression=expression) - - ALTER_PARSERS = { - "ADD": lambda self: self._parse_alter_table_add(), - "AS": lambda self: self._parse_select(), - "ALTER": lambda self: self._parse_alter_table_alter(), - "CLUSTER BY": lambda self: self._parse_cluster(wrapped=True), - "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()), - "DROP": lambda self: self._parse_alter_table_drop(), - "RENAME": lambda self: self._parse_alter_table_rename(), - "SET": lambda self: self._parse_alter_table_set(), - "SWAP": lambda self: self.expression( - exp.SwapTable, this=self._match(TokenType.WITH) and self._parse_table(schema=True) - ), - } - - ALTER_ALTER_PARSERS = { - "DISTKEY": lambda self: self._parse_alter_diststyle(), - "DISTSTYLE": lambda self: self._parse_alter_diststyle(), - "SORTKEY": lambda self: self._parse_alter_sortkey(), - "COMPOUND": lambda self: self._parse_alter_sortkey(compound=True), - } - - SCHEMA_UNNAMED_CONSTRAINTS = { - "CHECK", - "EXCLUDE", - "FOREIGN KEY", - "LIKE", - "PERIOD", - "PRIMARY KEY", - "UNIQUE", - "WATERMARK", - "BUCKET", - "TRUNCATE", - } - - NO_PAREN_FUNCTION_PARSERS = { - "ANY": lambda self: self.expression(exp.Any, this=self._parse_bitwise()), - "CASE": lambda self: self._parse_case(), - "CONNECT_BY_ROOT": lambda self: self.expression( - exp.ConnectByRoot, this=self._parse_column() - ), - "IF": lambda self: self._parse_if(), - } - - INVALID_FUNC_NAME_TOKENS = { - TokenType.IDENTIFIER, - TokenType.STRING, - } - - FUNCTIONS_WITH_ALIASED_ARGS = {"STRUCT"} - - KEY_VALUE_DEFINITIONS = (exp.Alias, exp.EQ, exp.PropertyEQ, exp.Slice) - - FUNCTION_PARSERS = { - **{ - name: lambda self: self._parse_max_min_by(exp.ArgMax) for name in exp.ArgMax.sql_names() - }, - **{ - name: lambda self: self._parse_max_min_by(exp.ArgMin) for name in exp.ArgMin.sql_names() - }, - "CAST": lambda self: self._parse_cast(self.STRICT_CAST), - "CEIL": lambda self: self._parse_ceil_floor(exp.Ceil), - "CONVERT": lambda self: self._parse_convert(self.STRICT_CAST), - "DECODE": lambda self: self._parse_decode(), - "EXTRACT": lambda self: self._parse_extract(), - "FLOOR": lambda self: self._parse_ceil_floor(exp.Floor), - "GAP_FILL": lambda self: self._parse_gap_fill(), - "JSON_OBJECT": lambda self: self._parse_json_object(), - "JSON_OBJECTAGG": lambda self: self._parse_json_object(agg=True), - "JSON_TABLE": lambda self: self._parse_json_table(), - "MATCH": lambda self: self._parse_match_against(), - "NORMALIZE": lambda self: self._parse_normalize(), - "OPENJSON": lambda self: self._parse_open_json(), - "OVERLAY": lambda self: self._parse_overlay(), - "POSITION": lambda self: self._parse_position(), - "PREDICT": lambda self: self._parse_predict(), - "SAFE_CAST": lambda self: self._parse_cast(False, safe=True), - "STRING_AGG": lambda self: self._parse_string_agg(), - "SUBSTRING": lambda self: self._parse_substring(), - "TRIM": lambda self: self._parse_trim(), - "TRY_CAST": lambda self: self._parse_cast(False, safe=True), - "TRY_CONVERT": lambda self: self._parse_convert(False, safe=True), - "XMLELEMENT": lambda self: self.expression( - exp.XMLElement, - this=self._match_text_seq("NAME") and self._parse_id_var(), - 
expressions=self._match(TokenType.COMMA) and self._parse_csv(self._parse_expression), - ), - "XMLTABLE": lambda self: self._parse_xml_table(), - } - - QUERY_MODIFIER_PARSERS = { - TokenType.MATCH_RECOGNIZE: lambda self: ("match", self._parse_match_recognize()), - TokenType.PREWHERE: lambda self: ("prewhere", self._parse_prewhere()), - TokenType.WHERE: lambda self: ("where", self._parse_where()), - TokenType.GROUP_BY: lambda self: ("group", self._parse_group()), - TokenType.HAVING: lambda self: ("having", self._parse_having()), - TokenType.QUALIFY: lambda self: ("qualify", self._parse_qualify()), - TokenType.WINDOW: lambda self: ("windows", self._parse_window_clause()), - TokenType.ORDER_BY: lambda self: ("order", self._parse_order()), - TokenType.LIMIT: lambda self: ("limit", self._parse_limit()), - TokenType.FETCH: lambda self: ("limit", self._parse_limit()), - TokenType.OFFSET: lambda self: ("offset", self._parse_offset()), - TokenType.FOR: lambda self: ("locks", self._parse_locks()), - TokenType.LOCK: lambda self: ("locks", self._parse_locks()), - TokenType.TABLE_SAMPLE: lambda self: ("sample", self._parse_table_sample(as_modifier=True)), - TokenType.USING: lambda self: ("sample", self._parse_table_sample(as_modifier=True)), - TokenType.CLUSTER_BY: lambda self: ( - "cluster", - self._parse_sort(exp.Cluster, TokenType.CLUSTER_BY), - ), - TokenType.DISTRIBUTE_BY: lambda self: ( - "distribute", - self._parse_sort(exp.Distribute, TokenType.DISTRIBUTE_BY), - ), - TokenType.SORT_BY: lambda self: ("sort", self._parse_sort(exp.Sort, TokenType.SORT_BY)), - TokenType.CONNECT_BY: lambda self: ("connect", self._parse_connect(skip_start_token=True)), - TokenType.START_WITH: lambda self: ("connect", self._parse_connect()), - } - - SET_PARSERS = { - "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"), - "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"), - "SESSION": lambda self: self._parse_set_item_assignment("SESSION"), - "TRANSACTION": lambda self: self._parse_set_transaction(), - } - - SHOW_PARSERS: t.Dict[str, t.Callable] = {} - - TYPE_LITERAL_PARSERS = { - exp.DataType.Type.JSON: lambda self, this, _: self.expression(exp.ParseJSON, this=this), - } - - TYPE_CONVERTERS: t.Dict[exp.DataType.Type, t.Callable[[exp.DataType], exp.DataType]] = {} - - DDL_SELECT_TOKENS = {TokenType.SELECT, TokenType.WITH, TokenType.L_PAREN} - - PRE_VOLATILE_TOKENS = {TokenType.CREATE, TokenType.REPLACE, TokenType.UNIQUE} - - TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} - TRANSACTION_CHARACTERISTICS: OPTIONS_TYPE = { - "ISOLATION": ( - ("LEVEL", "REPEATABLE", "READ"), - ("LEVEL", "READ", "COMMITTED"), - ("LEVEL", "READ", "UNCOMITTED"), - ("LEVEL", "SERIALIZABLE"), - ), - "READ": ("WRITE", "ONLY"), - } - - CONFLICT_ACTIONS: OPTIONS_TYPE = dict.fromkeys( - ("ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK", "UPDATE"), tuple() - ) - CONFLICT_ACTIONS["DO"] = ("NOTHING", "UPDATE") - - CREATE_SEQUENCE: OPTIONS_TYPE = { - "SCALE": ("EXTEND", "NOEXTEND"), - "SHARD": ("EXTEND", "NOEXTEND"), - "NO": ("CYCLE", "CACHE", "MAXVALUE", "MINVALUE"), - **dict.fromkeys( - ( - "SESSION", - "GLOBAL", - "KEEP", - "NOKEEP", - "ORDER", - "NOORDER", - "NOCACHE", - "CYCLE", - "NOCYCLE", - "NOMINVALUE", - "NOMAXVALUE", - "NOSCALE", - "NOSHARD", - ), - tuple(), - ), - } - - ISOLATED_LOADING_OPTIONS: OPTIONS_TYPE = {"FOR": ("ALL", "INSERT", "NONE")} - - USABLES: OPTIONS_TYPE = dict.fromkeys( - ("ROLE", "WAREHOUSE", "DATABASE", "SCHEMA", "CATALOG"), tuple() - ) - - CAST_ACTIONS: OPTIONS_TYPE = 
dict.fromkeys(("RENAME", "ADD"), ("FIELDS",)) - - SCHEMA_BINDING_OPTIONS: OPTIONS_TYPE = { - "TYPE": ("EVOLUTION",), - **dict.fromkeys(("BINDING", "COMPENSATION", "EVOLUTION"), tuple()), - } - - PROCEDURE_OPTIONS: OPTIONS_TYPE = {} - - EXECUTE_AS_OPTIONS: OPTIONS_TYPE = dict.fromkeys(("CALLER", "SELF", "OWNER"), tuple()) - - KEY_CONSTRAINT_OPTIONS: OPTIONS_TYPE = { - "NOT": ("ENFORCED",), - "MATCH": ( - "FULL", - "PARTIAL", - "SIMPLE", - ), - "INITIALLY": ("DEFERRED", "IMMEDIATE"), - "USING": ( - "BTREE", - "HASH", - ), - **dict.fromkeys(("DEFERRABLE", "NORELY", "RELY"), tuple()), - } - - WINDOW_EXCLUDE_OPTIONS: OPTIONS_TYPE = { - "NO": ("OTHERS",), - "CURRENT": ("ROW",), - **dict.fromkeys(("GROUP", "TIES"), tuple()), - } - - INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"} - - CLONE_KEYWORDS = {"CLONE", "COPY"} - HISTORICAL_DATA_PREFIX = {"AT", "BEFORE", "END"} - HISTORICAL_DATA_KIND = {"OFFSET", "STATEMENT", "STREAM", "TIMESTAMP", "VERSION"} - - OPCLASS_FOLLOW_KEYWORDS = {"ASC", "DESC", "NULLS", "WITH"} - - OPTYPE_FOLLOW_TOKENS = {TokenType.COMMA, TokenType.R_PAREN} - - TABLE_INDEX_HINT_TOKENS = {TokenType.FORCE, TokenType.IGNORE, TokenType.USE} - - VIEW_ATTRIBUTES = {"ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"} - - WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} - WINDOW_BEFORE_PAREN_TOKENS = {TokenType.OVER} - WINDOW_SIDES = {"FOLLOWING", "PRECEDING"} - - JSON_KEY_VALUE_SEPARATOR_TOKENS = {TokenType.COLON, TokenType.COMMA, TokenType.IS} - - FETCH_TOKENS = ID_VAR_TOKENS - {TokenType.ROW, TokenType.ROWS, TokenType.PERCENT} - - ADD_CONSTRAINT_TOKENS = { - TokenType.CONSTRAINT, - TokenType.FOREIGN_KEY, - TokenType.INDEX, - TokenType.KEY, - TokenType.PRIMARY_KEY, - TokenType.UNIQUE, - } - - DISTINCT_TOKENS = {TokenType.DISTINCT} - - NULL_TOKENS = {TokenType.NULL} - - UNNEST_OFFSET_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - SET_OPERATIONS - - SELECT_START_TOKENS = {TokenType.L_PAREN, TokenType.WITH, TokenType.SELECT} - - COPY_INTO_VARLEN_OPTIONS = {"FILE_FORMAT", "COPY_OPTIONS", "FORMAT_OPTIONS", "CREDENTIAL"} - - IS_JSON_PREDICATE_KIND = {"VALUE", "SCALAR", "ARRAY", "OBJECT"} - - ODBC_DATETIME_LITERALS = { - "d": exp.Date, - "t": exp.Time, - "ts": exp.Timestamp, - } - - ON_CONDITION_TOKENS = {"ERROR", "NULL", "TRUE", "FALSE", "EMPTY"} - - PRIVILEGE_FOLLOW_TOKENS = {TokenType.ON, TokenType.COMMA, TokenType.L_PAREN} - - # The style options for the DESCRIBE statement - DESCRIBE_STYLES = {"ANALYZE", "EXTENDED", "FORMATTED", "HISTORY"} - - # The style options for the ANALYZE statement - ANALYZE_STYLES = { - "BUFFER_USAGE_LIMIT", - "FULL", - "LOCAL", - "NO_WRITE_TO_BINLOG", - "SAMPLE", - "SKIP_LOCKED", - "VERBOSE", - } - - ANALYZE_EXPRESSION_PARSERS = { - "ALL": lambda self: self._parse_analyze_columns(), - "COMPUTE": lambda self: self._parse_analyze_statistics(), - "DELETE": lambda self: self._parse_analyze_delete(), - "DROP": lambda self: self._parse_analyze_histogram(), - "ESTIMATE": lambda self: self._parse_analyze_statistics(), - "LIST": lambda self: self._parse_analyze_list(), - "PREDICATE": lambda self: self._parse_analyze_columns(), - "UPDATE": lambda self: self._parse_analyze_histogram(), - "VALIDATE": lambda self: self._parse_analyze_validate(), - } - - PARTITION_KEYWORDS = {"PARTITION", "SUBPARTITION"} - - AMBIGUOUS_ALIAS_TOKENS = (TokenType.LIMIT, TokenType.OFFSET) - - OPERATION_MODIFIERS: t.Set[str] = set() - - RECURSIVE_CTE_SEARCH_KIND = {"BREADTH", "DEPTH", "CYCLE"} - - MODIFIABLES = (exp.Query, exp.Table, exp.TableFromRows) - - STRICT_CAST = True - - 
PREFIXED_PIVOT_COLUMNS = False - IDENTIFY_PIVOT_STRINGS = False - - LOG_DEFAULTS_TO_LN = False - - # Whether the table sample clause expects CSV syntax - TABLESAMPLE_CSV = False - - # The default method used for table sampling - DEFAULT_SAMPLING_METHOD: t.Optional[str] = None - - # Whether the SET command needs a delimiter (e.g. "=") for assignments - SET_REQUIRES_ASSIGNMENT_DELIMITER = True - - # Whether the TRIM function expects the characters to trim as its first argument - TRIM_PATTERN_FIRST = False - - # Whether string aliases are supported `SELECT COUNT(*) 'count'` - STRING_ALIASES = False - - # Whether query modifiers such as LIMIT are attached to the UNION node (vs its right operand) - MODIFIERS_ATTACHED_TO_SET_OP = True - SET_OP_MODIFIERS = {"order", "limit", "offset"} - - # Whether to parse IF statements that aren't followed by a left parenthesis as commands - NO_PAREN_IF_COMMANDS = True - - # Whether the -> and ->> operators expect documents of type JSON (e.g. Postgres) - JSON_ARROWS_REQUIRE_JSON_TYPE = False - - # Whether the `:` operator is used to extract a value from a VARIANT column - COLON_IS_VARIANT_EXTRACT = False - - # Whether or not a VALUES keyword needs to be followed by '(' to form a VALUES clause. - # If this is True and '(' is not found, the keyword will be treated as an identifier - VALUES_FOLLOWED_BY_PAREN = True - - # Whether implicit unnesting is supported, e.g. SELECT 1 FROM y.z AS z, z.a (Redshift) - SUPPORTS_IMPLICIT_UNNEST = False - - # Whether or not interval spans are supported, INTERVAL 1 YEAR TO MONTHS - INTERVAL_SPANS = True - - # Whether a PARTITION clause can follow a table reference - SUPPORTS_PARTITION_SELECTION = False - - # Whether the `name AS expr` schema/column constraint requires parentheses around `expr` - WRAPPED_TRANSFORM_COLUMN_CONSTRAINT = True - - # Whether the 'AS' keyword is optional in the CTE definition syntax - OPTIONAL_ALIAS_TOKEN_CTE = True - - # Whether renaming a column with an ALTER statement requires the presence of the COLUMN keyword - ALTER_RENAME_REQUIRES_COLUMN = True - - __slots__ = ( - "error_level", - "error_message_context", - "max_errors", - "dialect", - "sql", - "errors", - "_tokens", - "_index", - "_curr", - "_next", - "_prev", - "_prev_comments", - "_pipe_cte_counter", - ) - - # Autofilled - SHOW_TRIE: t.Dict = {} - SET_TRIE: t.Dict = {} - - def __init__( - self, - error_level: t.Optional[ErrorLevel] = None, - error_message_context: int = 100, - max_errors: int = 3, - dialect: DialectType = None, - ): - from sqlglot.dialects import Dialect - - self.error_level = error_level or ErrorLevel.IMMEDIATE - self.error_message_context = error_message_context - self.max_errors = max_errors - self.dialect = Dialect.get_or_raise(dialect) - self.reset() - - def reset(self): - self.sql = "" - self.errors = [] - self._tokens = [] - self._index = 0 - self._curr = None - self._next = None - self._prev = None - self._prev_comments = None - self._pipe_cte_counter = 0 - - def parse( - self, raw_tokens: t.List[Token], sql: t.Optional[str] = None - ) -> t.List[t.Optional[exp.Expression]]: - """ - Parses a list of tokens and returns a list of syntax trees, one tree - per parsed SQL statement. - - Args: - raw_tokens: The list of tokens. - sql: The original SQL string, used to produce helpful debug messages. - - Returns: - The list of the produced syntax trees. 
- """ - return self._parse( - parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql - ) - - def parse_into( - self, - expression_types: exp.IntoType, - raw_tokens: t.List[Token], - sql: t.Optional[str] = None, - ) -> t.List[t.Optional[exp.Expression]]: - """ - Parses a list of tokens into a given Expression type. If a collection of Expression - types is given instead, this method will try to parse the token list into each one - of them, stopping at the first for which the parsing succeeds. - - Args: - expression_types: The expression type(s) to try and parse the token list into. - raw_tokens: The list of tokens. - sql: The original SQL string, used to produce helpful debug messages. - - Returns: - The target Expression. - """ - errors = [] - for expression_type in ensure_list(expression_types): - parser = self.EXPRESSION_PARSERS.get(expression_type) - if not parser: - raise TypeError(f"No parser registered for {expression_type}") - - try: - return self._parse(parser, raw_tokens, sql) - except ParseError as e: - e.errors[0]["into_expression"] = expression_type - errors.append(e) - - raise ParseError( - f"Failed to parse '{sql or raw_tokens}' into {expression_types}", - errors=merge_errors(errors), - ) from errors[-1] - - def _parse( - self, - parse_method: t.Callable[[Parser], t.Optional[exp.Expression]], - raw_tokens: t.List[Token], - sql: t.Optional[str] = None, - ) -> t.List[t.Optional[exp.Expression]]: - self.reset() - self.sql = sql or "" - - total = len(raw_tokens) - chunks: t.List[t.List[Token]] = [[]] - - for i, token in enumerate(raw_tokens): - if token.token_type == TokenType.SEMICOLON: - if token.comments: - chunks.append([token]) - - if i < total - 1: - chunks.append([]) - else: - chunks[-1].append(token) - - expressions = [] - - for tokens in chunks: - self._index = -1 - self._tokens = tokens - self._advance() - - expressions.append(parse_method(self)) - - if self._index < len(self._tokens): - self.raise_error("Invalid expression / Unexpected token") - - self.check_errors() - - return expressions - - def check_errors(self) -> None: - """Logs or raises any found errors, depending on the chosen error level setting.""" - if self.error_level == ErrorLevel.WARN: - for error in self.errors: - logger.error(str(error)) - elif self.error_level == ErrorLevel.RAISE and self.errors: - raise ParseError( - concat_messages(self.errors, self.max_errors), - errors=merge_errors(self.errors), - ) - - def raise_error(self, message: str, token: t.Optional[Token] = None) -> None: - """ - Appends an error in the list of recorded errors or raises it, depending on the chosen - error level setting. - """ - token = token or self._curr or self._prev or Token.string("") - start = token.start - end = token.end + 1 - start_context = self.sql[max(start - self.error_message_context, 0) : start] - highlight = self.sql[start:end] - end_context = self.sql[end : end + self.error_message_context] - - error = ParseError.new( - f"{message}. Line {token.line}, Col: {token.col}.\n" - f" {start_context}\033[4m{highlight}\033[0m{end_context}", - description=message, - line=token.line, - col=token.col, - start_context=start_context, - highlight=highlight, - end_context=end_context, - ) - - if self.error_level == ErrorLevel.IMMEDIATE: - raise error - - self.errors.append(error) - - def expression( - self, exp_class: t.Type[E], comments: t.Optional[t.List[str]] = None, **kwargs - ) -> E: - """ - Creates a new, validated Expression. - - Args: - exp_class: The expression class to instantiate. 
- comments: An optional list of comments to attach to the expression. - kwargs: The arguments to set for the expression along with their respective values. - - Returns: - The target expression. - """ - instance = exp_class(**kwargs) - instance.add_comments(comments) if comments else self._add_comments(instance) - return self.validate_expression(instance) - - def _add_comments(self, expression: t.Optional[exp.Expression]) -> None: - if expression and self._prev_comments: - expression.add_comments(self._prev_comments) - self._prev_comments = None - - def validate_expression(self, expression: E, args: t.Optional[t.List] = None) -> E: - """ - Validates an Expression, making sure that all its mandatory arguments are set. - - Args: - expression: The expression to validate. - args: An optional list of items that was used to instantiate the expression, if it's a Func. - - Returns: - The validated expression. - """ - if self.error_level != ErrorLevel.IGNORE: - for error_message in expression.error_messages(args): - self.raise_error(error_message) - - return expression - - def _find_sql(self, start: Token, end: Token) -> str: - return self.sql[start.start : end.end + 1] - - def _is_connected(self) -> bool: - return self._prev and self._curr and self._prev.end + 1 == self._curr.start - - def _advance(self, times: int = 1) -> None: - self._index += times - self._curr = seq_get(self._tokens, self._index) - self._next = seq_get(self._tokens, self._index + 1) - - if self._index > 0: - self._prev = self._tokens[self._index - 1] - self._prev_comments = self._prev.comments - else: - self._prev = None - self._prev_comments = None - - def _retreat(self, index: int) -> None: - if index != self._index: - self._advance(index - self._index) - - def _warn_unsupported(self) -> None: - if len(self._tokens) <= 1: - return - - # We use _find_sql because self.sql may comprise multiple chunks, and we're only - # interested in emitting a warning for the one being currently processed. - sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context] - - logger.warning( - f"'{sql}' contains unsupported syntax. Falling back to parsing as a 'Command'." - ) - - def _parse_command(self) -> exp.Command: - self._warn_unsupported() - return self.expression( - exp.Command, - comments=self._prev_comments, - this=self._prev.text.upper(), - expression=self._parse_string(), - ) - - def _try_parse(self, parse_method: t.Callable[[], T], retreat: bool = False) -> t.Optional[T]: - """ - Attemps to backtrack if a parse function that contains a try/catch internally raises an error. 
- This behavior can be different depending on the uset-set ErrorLevel, so _try_parse aims to - solve this by setting & resetting the parser state accordingly - """ - index = self._index - error_level = self.error_level - - self.error_level = ErrorLevel.IMMEDIATE - try: - this = parse_method() - except ParseError: - this = None - finally: - if not this or retreat: - self._retreat(index) - self.error_level = error_level - - return this - - def _parse_comment(self, allow_exists: bool = True) -> exp.Expression: - start = self._prev - exists = self._parse_exists() if allow_exists else None - - self._match(TokenType.ON) - - materialized = self._match_text_seq("MATERIALIZED") - kind = self._match_set(self.CREATABLES) and self._prev - if not kind: - return self._parse_as_command(start) - - if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): - this = self._parse_user_defined_function(kind=kind.token_type) - elif kind.token_type == TokenType.TABLE: - this = self._parse_table(alias_tokens=self.COMMENT_TABLE_ALIAS_TOKENS) - elif kind.token_type == TokenType.COLUMN: - this = self._parse_column() - else: - this = self._parse_id_var() - - self._match(TokenType.IS) - - return self.expression( - exp.Comment, - this=this, - kind=kind.text, - expression=self._parse_string(), - exists=exists, - materialized=materialized, - ) - - def _parse_to_table( - self, - ) -> exp.ToTableProperty: - table = self._parse_table_parts(schema=True) - return self.expression(exp.ToTableProperty, this=table) - - # https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree#mergetree-table-ttl - def _parse_ttl(self) -> exp.Expression: - def _parse_ttl_action() -> t.Optional[exp.Expression]: - this = self._parse_bitwise() - - if self._match_text_seq("DELETE"): - return self.expression(exp.MergeTreeTTLAction, this=this, delete=True) - if self._match_text_seq("RECOMPRESS"): - return self.expression( - exp.MergeTreeTTLAction, this=this, recompress=self._parse_bitwise() - ) - if self._match_text_seq("TO", "DISK"): - return self.expression( - exp.MergeTreeTTLAction, this=this, to_disk=self._parse_string() - ) - if self._match_text_seq("TO", "VOLUME"): - return self.expression( - exp.MergeTreeTTLAction, this=this, to_volume=self._parse_string() - ) - - return this - - expressions = self._parse_csv(_parse_ttl_action) - where = self._parse_where() - group = self._parse_group() - - aggregates = None - if group and self._match(TokenType.SET): - aggregates = self._parse_csv(self._parse_set_item) - - return self.expression( - exp.MergeTreeTTL, - expressions=expressions, - where=where, - group=group, - aggregates=aggregates, - ) - - def _parse_statement(self) -> t.Optional[exp.Expression]: - if self._curr is None: - return None - - if self._match_set(self.STATEMENT_PARSERS): - comments = self._prev_comments - stmt = self.STATEMENT_PARSERS[self._prev.token_type](self) - stmt.add_comments(comments, prepend=True) - return stmt - - if self._match_set(self.dialect.tokenizer.COMMANDS): - return self._parse_command() - - expression = self._parse_expression() - expression = self._parse_set_operations(expression) if expression else self._parse_select() - return self._parse_query_modifiers(expression) - - def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command: - start = self._prev - temporary = self._match(TokenType.TEMPORARY) - materialized = self._match_text_seq("MATERIALIZED") - - kind = self._match_set(self.CREATABLES) and self._prev.text.upper() - if not kind: - return self._parse_as_command(start) - - 
concurrently = self._match_text_seq("CONCURRENTLY") - if_exists = exists or self._parse_exists() - - if kind == "COLUMN": - this = self._parse_column() - else: - this = self._parse_table_parts( - schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA - ) - - cluster = self._parse_on_property() if self._match(TokenType.ON) else None - - if self._match(TokenType.L_PAREN, advance=False): - expressions = self._parse_wrapped_csv(self._parse_types) - else: - expressions = None - - return self.expression( - exp.Drop, - exists=if_exists, - this=this, - expressions=expressions, - kind=self.dialect.CREATABLE_KIND_MAPPING.get(kind) or kind, - temporary=temporary, - materialized=materialized, - cascade=self._match_text_seq("CASCADE"), - constraints=self._match_text_seq("CONSTRAINTS"), - purge=self._match_text_seq("PURGE"), - cluster=cluster, - concurrently=concurrently, - ) - - def _parse_exists(self, not_: bool = False) -> t.Optional[bool]: - return ( - self._match_text_seq("IF") - and (not not_ or self._match(TokenType.NOT)) - and self._match(TokenType.EXISTS) - ) - - def _parse_create(self) -> exp.Create | exp.Command: - # Note: this can't be None because we've matched a statement parser - start = self._prev - - replace = ( - start.token_type == TokenType.REPLACE - or self._match_pair(TokenType.OR, TokenType.REPLACE) - or self._match_pair(TokenType.OR, TokenType.ALTER) - ) - refresh = self._match_pair(TokenType.OR, TokenType.REFRESH) - - unique = self._match(TokenType.UNIQUE) - - if self._match_text_seq("CLUSTERED", "COLUMNSTORE"): - clustered = True - elif self._match_text_seq("NONCLUSTERED", "COLUMNSTORE") or self._match_text_seq( - "COLUMNSTORE" - ): - clustered = False - else: - clustered = None - - if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False): - self._advance() - - properties = None - create_token = self._match_set(self.CREATABLES) and self._prev - - if not create_token: - # exp.Properties.Location.POST_CREATE - properties = self._parse_properties() - create_token = self._match_set(self.CREATABLES) and self._prev - - if not properties or not create_token: - return self._parse_as_command(start) - - concurrently = self._match_text_seq("CONCURRENTLY") - exists = self._parse_exists(not_=True) - this = None - expression: t.Optional[exp.Expression] = None - indexes = None - no_schema_binding = None - begin = None - end = None - clone = None - - def extend_props(temp_props: t.Optional[exp.Properties]) -> None: - nonlocal properties - if properties and temp_props: - properties.expressions.extend(temp_props.expressions) - elif temp_props: - properties = temp_props - - if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): - this = self._parse_user_defined_function(kind=create_token.token_type) - - # exp.Properties.Location.POST_SCHEMA ("schema" here is the UDF's type signature) - extend_props(self._parse_properties()) - - expression = self._match(TokenType.ALIAS) and self._parse_heredoc() - extend_props(self._parse_properties()) - - if not expression: - if self._match(TokenType.COMMAND): - expression = self._parse_as_command(self._prev) - else: - begin = self._match(TokenType.BEGIN) - return_ = self._match_text_seq("RETURN") - - if self._match(TokenType.STRING, advance=False): - # Takes care of BigQuery's JavaScript UDF definitions that end in an OPTIONS property - # # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#create_function_statement - expression = self._parse_string() - 
extend_props(self._parse_properties()) - else: - expression = self._parse_user_defined_function_expression() - - end = self._match_text_seq("END") - - if return_: - expression = self.expression(exp.Return, this=expression) - elif create_token.token_type == TokenType.INDEX: - # Postgres allows anonymous indexes, eg. CREATE INDEX IF NOT EXISTS ON t(c) - if not self._match(TokenType.ON): - index = self._parse_id_var() - anonymous = False - else: - index = None - anonymous = True - - this = self._parse_index(index=index, anonymous=anonymous) - elif create_token.token_type in self.DB_CREATABLES: - table_parts = self._parse_table_parts( - schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA - ) - - # exp.Properties.Location.POST_NAME - self._match(TokenType.COMMA) - extend_props(self._parse_properties(before=True)) - - this = self._parse_schema(this=table_parts) - - # exp.Properties.Location.POST_SCHEMA and POST_WITH - extend_props(self._parse_properties()) - - has_alias = self._match(TokenType.ALIAS) - if not self._match_set(self.DDL_SELECT_TOKENS, advance=False): - # exp.Properties.Location.POST_ALIAS - extend_props(self._parse_properties()) - - if create_token.token_type == TokenType.SEQUENCE: - expression = self._parse_types() - extend_props(self._parse_properties()) - else: - expression = self._parse_ddl_select() - - # Some dialects also support using a table as an alias instead of a SELECT. - # Here we fallback to this as an alternative. - if not expression and has_alias: - expression = self._try_parse(self._parse_table_parts) - - if create_token.token_type == TokenType.TABLE: - # exp.Properties.Location.POST_EXPRESSION - extend_props(self._parse_properties()) - - indexes = [] - while True: - index = self._parse_index() - - # exp.Properties.Location.POST_INDEX - extend_props(self._parse_properties()) - if not index: - break - else: - self._match(TokenType.COMMA) - indexes.append(index) - elif create_token.token_type == TokenType.VIEW: - if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"): - no_schema_binding = True - elif create_token.token_type in (TokenType.SINK, TokenType.SOURCE): - extend_props(self._parse_properties()) - - shallow = self._match_text_seq("SHALLOW") - - if self._match_texts(self.CLONE_KEYWORDS): - copy = self._prev.text.lower() == "copy" - clone = self.expression( - exp.Clone, this=self._parse_table(schema=True), shallow=shallow, copy=copy - ) - - if self._curr and not self._match_set((TokenType.R_PAREN, TokenType.COMMA), advance=False): - return self._parse_as_command(start) - - create_kind_text = create_token.text.upper() - return self.expression( - exp.Create, - this=this, - kind=self.dialect.CREATABLE_KIND_MAPPING.get(create_kind_text) or create_kind_text, - replace=replace, - refresh=refresh, - unique=unique, - expression=expression, - exists=exists, - properties=properties, - indexes=indexes, - no_schema_binding=no_schema_binding, - begin=begin, - end=end, - clone=clone, - concurrently=concurrently, - clustered=clustered, - ) - - def _parse_sequence_properties(self) -> t.Optional[exp.SequenceProperties]: - seq = exp.SequenceProperties() - - options = [] - index = self._index - - while self._curr: - self._match(TokenType.COMMA) - if self._match_text_seq("INCREMENT"): - self._match_text_seq("BY") - self._match_text_seq("=") - seq.set("increment", self._parse_term()) - elif self._match_text_seq("MINVALUE"): - seq.set("minvalue", self._parse_term()) - elif self._match_text_seq("MAXVALUE"): - seq.set("maxvalue", self._parse_term()) - elif 
self._match(TokenType.START_WITH) or self._match_text_seq("START"): - self._match_text_seq("=") - seq.set("start", self._parse_term()) - elif self._match_text_seq("CACHE"): - # T-SQL allows empty CACHE which is initialized dynamically - seq.set("cache", self._parse_number() or True) - elif self._match_text_seq("OWNED", "BY"): - # "OWNED BY NONE" is the default - seq.set("owned", None if self._match_text_seq("NONE") else self._parse_column()) - else: - opt = self._parse_var_from_options(self.CREATE_SEQUENCE, raise_unmatched=False) - if opt: - options.append(opt) - else: - break - - seq.set("options", options if options else None) - return None if self._index == index else seq - - def _parse_property_before(self) -> t.Optional[exp.Expression]: - # only used for teradata currently - self._match(TokenType.COMMA) - - kwargs = { - "no": self._match_text_seq("NO"), - "dual": self._match_text_seq("DUAL"), - "before": self._match_text_seq("BEFORE"), - "default": self._match_text_seq("DEFAULT"), - "local": (self._match_text_seq("LOCAL") and "LOCAL") - or (self._match_text_seq("NOT", "LOCAL") and "NOT LOCAL"), - "after": self._match_text_seq("AFTER"), - "minimum": self._match_texts(("MIN", "MINIMUM")), - "maximum": self._match_texts(("MAX", "MAXIMUM")), - } - - if self._match_texts(self.PROPERTY_PARSERS): - parser = self.PROPERTY_PARSERS[self._prev.text.upper()] - try: - return parser(self, **{k: v for k, v in kwargs.items() if v}) - except TypeError: - self.raise_error(f"Cannot parse property '{self._prev.text}'") - - return None - - def _parse_wrapped_properties(self) -> t.List[exp.Expression]: - return self._parse_wrapped_csv(self._parse_property) - - def _parse_property(self) -> t.Optional[exp.Expression]: - if self._match_texts(self.PROPERTY_PARSERS): - return self.PROPERTY_PARSERS[self._prev.text.upper()](self) - - if self._match(TokenType.DEFAULT) and self._match_texts(self.PROPERTY_PARSERS): - return self.PROPERTY_PARSERS[self._prev.text.upper()](self, default=True) - - if self._match_text_seq("COMPOUND", "SORTKEY"): - return self._parse_sortkey(compound=True) - - if self._match_text_seq("SQL", "SECURITY"): - return self.expression(exp.SqlSecurityProperty, definer=self._match_text_seq("DEFINER")) - - index = self._index - key = self._parse_column() - - if not self._match(TokenType.EQ): - self._retreat(index) - return self._parse_sequence_properties() - - # Transform the key to exp.Dot if it's dotted identifiers wrapped in exp.Column or to exp.Var otherwise - if isinstance(key, exp.Column): - key = key.to_dot() if len(key.parts) > 1 else exp.var(key.name) - - value = self._parse_bitwise() or self._parse_var(any_token=True) - - # Transform the value to exp.Var if it was parsed as exp.Column(exp.Identifier()) - if isinstance(value, exp.Column): - value = exp.var(value.name) - - return self.expression(exp.Property, this=key, value=value) - - def _parse_stored(self) -> t.Union[exp.FileFormatProperty, exp.StorageHandlerProperty]: - if self._match_text_seq("BY"): - return self.expression(exp.StorageHandlerProperty, this=self._parse_var_or_string()) - - self._match(TokenType.ALIAS) - input_format = self._parse_string() if self._match_text_seq("INPUTFORMAT") else None - output_format = self._parse_string() if self._match_text_seq("OUTPUTFORMAT") else None - - return self.expression( - exp.FileFormatProperty, - this=( - self.expression( - exp.InputOutputFormat, - input_format=input_format, - output_format=output_format, - ) - if input_format or output_format - else self._parse_var_or_string() or 
self._parse_number() or self._parse_id_var() - ), - ) - - def _parse_unquoted_field(self) -> t.Optional[exp.Expression]: - field = self._parse_field() - if isinstance(field, exp.Identifier) and not field.quoted: - field = exp.var(field) - - return field - - def _parse_property_assignment(self, exp_class: t.Type[E], **kwargs: t.Any) -> E: - self._match(TokenType.EQ) - self._match(TokenType.ALIAS) - - return self.expression(exp_class, this=self._parse_unquoted_field(), **kwargs) - - def _parse_properties(self, before: t.Optional[bool] = None) -> t.Optional[exp.Properties]: - properties = [] - while True: - if before: - prop = self._parse_property_before() - else: - prop = self._parse_property() - if not prop: - break - for p in ensure_list(prop): - properties.append(p) - - if properties: - return self.expression(exp.Properties, expressions=properties) - - return None - - def _parse_fallback(self, no: bool = False) -> exp.FallbackProperty: - return self.expression( - exp.FallbackProperty, no=no, protection=self._match_text_seq("PROTECTION") - ) - - def _parse_security(self) -> t.Optional[exp.SecurityProperty]: - if self._match_texts(("NONE", "DEFINER", "INVOKER")): - security_specifier = self._prev.text.upper() - return self.expression(exp.SecurityProperty, this=security_specifier) - return None - - def _parse_settings_property(self) -> exp.SettingsProperty: - return self.expression( - exp.SettingsProperty, expressions=self._parse_csv(self._parse_assignment) - ) - - def _parse_volatile_property(self) -> exp.VolatileProperty | exp.StabilityProperty: - if self._index >= 2: - pre_volatile_token = self._tokens[self._index - 2] - else: - pre_volatile_token = None - - if pre_volatile_token and pre_volatile_token.token_type in self.PRE_VOLATILE_TOKENS: - return exp.VolatileProperty() - - return self.expression(exp.StabilityProperty, this=exp.Literal.string("VOLATILE")) - - def _parse_retention_period(self) -> exp.Var: - # Parse TSQL's HISTORY_RETENTION_PERIOD: {INFINITE | DAY | DAYS | MONTH ...} - number = self._parse_number() - number_str = f"{number} " if number else "" - unit = self._parse_var(any_token=True) - return exp.var(f"{number_str}{unit}") - - def _parse_system_versioning_property( - self, with_: bool = False - ) -> exp.WithSystemVersioningProperty: - self._match(TokenType.EQ) - prop = self.expression( - exp.WithSystemVersioningProperty, - **{ # type: ignore - "on": True, - "with": with_, - }, - ) - - if self._match_text_seq("OFF"): - prop.set("on", False) - return prop - - self._match(TokenType.ON) - if self._match(TokenType.L_PAREN): - while self._curr and not self._match(TokenType.R_PAREN): - if self._match_text_seq("HISTORY_TABLE", "="): - prop.set("this", self._parse_table_parts()) - elif self._match_text_seq("DATA_CONSISTENCY_CHECK", "="): - prop.set("data_consistency", self._advance_any() and self._prev.text.upper()) - elif self._match_text_seq("HISTORY_RETENTION_PERIOD", "="): - prop.set("retention_period", self._parse_retention_period()) - - self._match(TokenType.COMMA) - - return prop - - def _parse_data_deletion_property(self) -> exp.DataDeletionProperty: - self._match(TokenType.EQ) - on = self._match_text_seq("ON") or not self._match_text_seq("OFF") - prop = self.expression(exp.DataDeletionProperty, on=on) - - if self._match(TokenType.L_PAREN): - while self._curr and not self._match(TokenType.R_PAREN): - if self._match_text_seq("FILTER_COLUMN", "="): - prop.set("filter_column", self._parse_column()) - elif self._match_text_seq("RETENTION_PERIOD", "="): - 
prop.set("retention_period", self._parse_retention_period()) - - self._match(TokenType.COMMA) - - return prop - - def _parse_distributed_property(self) -> exp.DistributedByProperty: - kind = "HASH" - expressions: t.Optional[t.List[exp.Expression]] = None - if self._match_text_seq("BY", "HASH"): - expressions = self._parse_wrapped_csv(self._parse_id_var) - elif self._match_text_seq("BY", "RANDOM"): - kind = "RANDOM" - - # If the BUCKETS keyword is not present, the number of buckets is AUTO - buckets: t.Optional[exp.Expression] = None - if self._match_text_seq("BUCKETS") and not self._match_text_seq("AUTO"): - buckets = self._parse_number() - - return self.expression( - exp.DistributedByProperty, - expressions=expressions, - kind=kind, - buckets=buckets, - order=self._parse_order(), - ) - - def _parse_composite_key_property(self, expr_type: t.Type[E]) -> E: - self._match_text_seq("KEY") - expressions = self._parse_wrapped_id_vars() - return self.expression(expr_type, expressions=expressions) - - def _parse_with_property(self) -> t.Optional[exp.Expression] | t.List[exp.Expression]: - if self._match_text_seq("(", "SYSTEM_VERSIONING"): - prop = self._parse_system_versioning_property(with_=True) - self._match_r_paren() - return prop - - if self._match(TokenType.L_PAREN, advance=False): - return self._parse_wrapped_properties() - - if self._match_text_seq("JOURNAL"): - return self._parse_withjournaltable() - - if self._match_texts(self.VIEW_ATTRIBUTES): - return self.expression(exp.ViewAttributeProperty, this=self._prev.text.upper()) - - if self._match_text_seq("DATA"): - return self._parse_withdata(no=False) - elif self._match_text_seq("NO", "DATA"): - return self._parse_withdata(no=True) - - if self._match(TokenType.SERDE_PROPERTIES, advance=False): - return self._parse_serde_properties(with_=True) - - if self._match(TokenType.SCHEMA): - return self.expression( - exp.WithSchemaBindingProperty, - this=self._parse_var_from_options(self.SCHEMA_BINDING_OPTIONS), - ) - - if self._match_texts(self.PROCEDURE_OPTIONS, advance=False): - return self.expression( - exp.WithProcedureOptions, expressions=self._parse_csv(self._parse_procedure_option) - ) - - if not self._next: - return None - - return self._parse_withisolatedloading() - - def _parse_procedure_option(self) -> exp.Expression | None: - if self._match_text_seq("EXECUTE", "AS"): - return self.expression( - exp.ExecuteAsProperty, - this=self._parse_var_from_options(self.EXECUTE_AS_OPTIONS, raise_unmatched=False) - or self._parse_string(), - ) - - return self._parse_var_from_options(self.PROCEDURE_OPTIONS) - - # https://dev.mysql.com/doc/refman/8.0/en/create-view.html - def _parse_definer(self) -> t.Optional[exp.DefinerProperty]: - self._match(TokenType.EQ) - - user = self._parse_id_var() - self._match(TokenType.PARAMETER) - host = self._parse_id_var() or (self._match(TokenType.MOD) and self._prev.text) - - if not user or not host: - return None - - return exp.DefinerProperty(this=f"{user}@{host}") - - def _parse_withjournaltable(self) -> exp.WithJournalTableProperty: - self._match(TokenType.TABLE) - self._match(TokenType.EQ) - return self.expression(exp.WithJournalTableProperty, this=self._parse_table_parts()) - - def _parse_log(self, no: bool = False) -> exp.LogProperty: - return self.expression(exp.LogProperty, no=no) - - def _parse_journal(self, **kwargs) -> exp.JournalProperty: - return self.expression(exp.JournalProperty, **kwargs) - - def _parse_checksum(self) -> exp.ChecksumProperty: - self._match(TokenType.EQ) - - on = None - if 
self._match(TokenType.ON): - on = True - elif self._match_text_seq("OFF"): - on = False - - return self.expression(exp.ChecksumProperty, on=on, default=self._match(TokenType.DEFAULT)) - - def _parse_cluster(self, wrapped: bool = False) -> exp.Cluster: - return self.expression( - exp.Cluster, - expressions=( - self._parse_wrapped_csv(self._parse_ordered) - if wrapped - else self._parse_csv(self._parse_ordered) - ), - ) - - def _parse_clustered_by(self) -> exp.ClusteredByProperty: - self._match_text_seq("BY") - - self._match_l_paren() - expressions = self._parse_csv(self._parse_column) - self._match_r_paren() - - if self._match_text_seq("SORTED", "BY"): - self._match_l_paren() - sorted_by = self._parse_csv(self._parse_ordered) - self._match_r_paren() - else: - sorted_by = None - - self._match(TokenType.INTO) - buckets = self._parse_number() - self._match_text_seq("BUCKETS") - - return self.expression( - exp.ClusteredByProperty, - expressions=expressions, - sorted_by=sorted_by, - buckets=buckets, - ) - - def _parse_copy_property(self) -> t.Optional[exp.CopyGrantsProperty]: - if not self._match_text_seq("GRANTS"): - self._retreat(self._index - 1) - return None - - return self.expression(exp.CopyGrantsProperty) - - def _parse_freespace(self) -> exp.FreespaceProperty: - self._match(TokenType.EQ) - return self.expression( - exp.FreespaceProperty, this=self._parse_number(), percent=self._match(TokenType.PERCENT) - ) - - def _parse_mergeblockratio( - self, no: bool = False, default: bool = False - ) -> exp.MergeBlockRatioProperty: - if self._match(TokenType.EQ): - return self.expression( - exp.MergeBlockRatioProperty, - this=self._parse_number(), - percent=self._match(TokenType.PERCENT), - ) - - return self.expression(exp.MergeBlockRatioProperty, no=no, default=default) - - def _parse_datablocksize( - self, - default: t.Optional[bool] = None, - minimum: t.Optional[bool] = None, - maximum: t.Optional[bool] = None, - ) -> exp.DataBlocksizeProperty: - self._match(TokenType.EQ) - size = self._parse_number() - - units = None - if self._match_texts(("BYTES", "KBYTES", "KILOBYTES")): - units = self._prev.text - - return self.expression( - exp.DataBlocksizeProperty, - size=size, - units=units, - default=default, - minimum=minimum, - maximum=maximum, - ) - - def _parse_blockcompression(self) -> exp.BlockCompressionProperty: - self._match(TokenType.EQ) - always = self._match_text_seq("ALWAYS") - manual = self._match_text_seq("MANUAL") - never = self._match_text_seq("NEVER") - default = self._match_text_seq("DEFAULT") - - autotemp = None - if self._match_text_seq("AUTOTEMP"): - autotemp = self._parse_schema() - - return self.expression( - exp.BlockCompressionProperty, - always=always, - manual=manual, - never=never, - default=default, - autotemp=autotemp, - ) - - def _parse_withisolatedloading(self) -> t.Optional[exp.IsolatedLoadingProperty]: - index = self._index - no = self._match_text_seq("NO") - concurrent = self._match_text_seq("CONCURRENT") - - if not self._match_text_seq("ISOLATED", "LOADING"): - self._retreat(index) - return None - - target = self._parse_var_from_options(self.ISOLATED_LOADING_OPTIONS, raise_unmatched=False) - return self.expression( - exp.IsolatedLoadingProperty, no=no, concurrent=concurrent, target=target - ) - - def _parse_locking(self) -> exp.LockingProperty: - if self._match(TokenType.TABLE): - kind = "TABLE" - elif self._match(TokenType.VIEW): - kind = "VIEW" - elif self._match(TokenType.ROW): - kind = "ROW" - elif self._match_text_seq("DATABASE"): - kind = "DATABASE" - else: - 
kind = None - - if kind in ("DATABASE", "TABLE", "VIEW"): - this = self._parse_table_parts() - else: - this = None - - if self._match(TokenType.FOR): - for_or_in = "FOR" - elif self._match(TokenType.IN): - for_or_in = "IN" - else: - for_or_in = None - - if self._match_text_seq("ACCESS"): - lock_type = "ACCESS" - elif self._match_texts(("EXCL", "EXCLUSIVE")): - lock_type = "EXCLUSIVE" - elif self._match_text_seq("SHARE"): - lock_type = "SHARE" - elif self._match_text_seq("READ"): - lock_type = "READ" - elif self._match_text_seq("WRITE"): - lock_type = "WRITE" - elif self._match_text_seq("CHECKSUM"): - lock_type = "CHECKSUM" - else: - lock_type = None - - override = self._match_text_seq("OVERRIDE") - - return self.expression( - exp.LockingProperty, - this=this, - kind=kind, - for_or_in=for_or_in, - lock_type=lock_type, - override=override, - ) - - def _parse_partition_by(self) -> t.List[exp.Expression]: - if self._match(TokenType.PARTITION_BY): - return self._parse_csv(self._parse_assignment) - return [] - - def _parse_partition_bound_spec(self) -> exp.PartitionBoundSpec: - def _parse_partition_bound_expr() -> t.Optional[exp.Expression]: - if self._match_text_seq("MINVALUE"): - return exp.var("MINVALUE") - if self._match_text_seq("MAXVALUE"): - return exp.var("MAXVALUE") - return self._parse_bitwise() - - this: t.Optional[exp.Expression | t.List[exp.Expression]] = None - expression = None - from_expressions = None - to_expressions = None - - if self._match(TokenType.IN): - this = self._parse_wrapped_csv(self._parse_bitwise) - elif self._match(TokenType.FROM): - from_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr) - self._match_text_seq("TO") - to_expressions = self._parse_wrapped_csv(_parse_partition_bound_expr) - elif self._match_text_seq("WITH", "(", "MODULUS"): - this = self._parse_number() - self._match_text_seq(",", "REMAINDER") - expression = self._parse_number() - self._match_r_paren() - else: - self.raise_error("Failed to parse partition bound spec.") - - return self.expression( - exp.PartitionBoundSpec, - this=this, - expression=expression, - from_expressions=from_expressions, - to_expressions=to_expressions, - ) - - # https://www.postgresql.org/docs/current/sql-createtable.html - def _parse_partitioned_of(self) -> t.Optional[exp.PartitionedOfProperty]: - if not self._match_text_seq("OF"): - self._retreat(self._index - 1) - return None - - this = self._parse_table(schema=True) - - if self._match(TokenType.DEFAULT): - expression: exp.Var | exp.PartitionBoundSpec = exp.var("DEFAULT") - elif self._match_text_seq("FOR", "VALUES"): - expression = self._parse_partition_bound_spec() - else: - self.raise_error("Expecting either DEFAULT or FOR VALUES clause.") - - return self.expression(exp.PartitionedOfProperty, this=this, expression=expression) - - def _parse_partitioned_by(self) -> exp.PartitionedByProperty: - self._match(TokenType.EQ) - return self.expression( - exp.PartitionedByProperty, - this=self._parse_schema() or self._parse_bracket(self._parse_field()), - ) - - def _parse_withdata(self, no: bool = False) -> exp.WithDataProperty: - if self._match_text_seq("AND", "STATISTICS"): - statistics = True - elif self._match_text_seq("AND", "NO", "STATISTICS"): - statistics = False - else: - statistics = None - - return self.expression(exp.WithDataProperty, no=no, statistics=statistics) - - def _parse_contains_property(self) -> t.Optional[exp.SqlReadWriteProperty]: - if self._match_text_seq("SQL"): - return self.expression(exp.SqlReadWriteProperty, this="CONTAINS SQL") - 
return None - - def _parse_modifies_property(self) -> t.Optional[exp.SqlReadWriteProperty]: - if self._match_text_seq("SQL", "DATA"): - return self.expression(exp.SqlReadWriteProperty, this="MODIFIES SQL DATA") - return None - - def _parse_no_property(self) -> t.Optional[exp.Expression]: - if self._match_text_seq("PRIMARY", "INDEX"): - return exp.NoPrimaryIndexProperty() - if self._match_text_seq("SQL"): - return self.expression(exp.SqlReadWriteProperty, this="NO SQL") - return None - - def _parse_on_property(self) -> t.Optional[exp.Expression]: - if self._match_text_seq("COMMIT", "PRESERVE", "ROWS"): - return exp.OnCommitProperty() - if self._match_text_seq("COMMIT", "DELETE", "ROWS"): - return exp.OnCommitProperty(delete=True) - return self.expression(exp.OnProperty, this=self._parse_schema(self._parse_id_var())) - - def _parse_reads_property(self) -> t.Optional[exp.SqlReadWriteProperty]: - if self._match_text_seq("SQL", "DATA"): - return self.expression(exp.SqlReadWriteProperty, this="READS SQL DATA") - return None - - def _parse_distkey(self) -> exp.DistKeyProperty: - return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) - - def _parse_create_like(self) -> t.Optional[exp.LikeProperty]: - table = self._parse_table(schema=True) - - options = [] - while self._match_texts(("INCLUDING", "EXCLUDING")): - this = self._prev.text.upper() - - id_var = self._parse_id_var() - if not id_var: - return None - - options.append( - self.expression(exp.Property, this=this, value=exp.var(id_var.this.upper())) - ) - - return self.expression(exp.LikeProperty, this=table, expressions=options) - - def _parse_sortkey(self, compound: bool = False) -> exp.SortKeyProperty: - return self.expression( - exp.SortKeyProperty, this=self._parse_wrapped_id_vars(), compound=compound - ) - - def _parse_character_set(self, default: bool = False) -> exp.CharacterSetProperty: - self._match(TokenType.EQ) - return self.expression( - exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default - ) - - def _parse_remote_with_connection(self) -> exp.RemoteWithConnectionModelProperty: - self._match_text_seq("WITH", "CONNECTION") - return self.expression( - exp.RemoteWithConnectionModelProperty, this=self._parse_table_parts() - ) - - def _parse_returns(self) -> exp.ReturnsProperty: - value: t.Optional[exp.Expression] - null = None - is_table = self._match(TokenType.TABLE) - - if is_table: - if self._match(TokenType.LT): - value = self.expression( - exp.Schema, - this="TABLE", - expressions=self._parse_csv(self._parse_struct_types), - ) - if not self._match(TokenType.GT): - self.raise_error("Expecting >") - else: - value = self._parse_schema(exp.var("TABLE")) - elif self._match_text_seq("NULL", "ON", "NULL", "INPUT"): - null = True - value = None - else: - value = self._parse_types() - - return self.expression(exp.ReturnsProperty, this=value, is_table=is_table, null=null) - - def _parse_describe(self) -> exp.Describe: - kind = self._match_set(self.CREATABLES) and self._prev.text - style = self._match_texts(self.DESCRIBE_STYLES) and self._prev.text.upper() - if self._match(TokenType.DOT): - style = None - self._retreat(self._index - 2) - - format = self._parse_property() if self._match(TokenType.FORMAT, advance=False) else None - - if self._match_set(self.STATEMENT_PARSERS, advance=False): - this = self._parse_statement() - else: - this = self._parse_table(schema=True) - - properties = self._parse_properties() - expressions = properties.expressions if properties else None - partition 
= self._parse_partition() - return self.expression( - exp.Describe, - this=this, - style=style, - kind=kind, - expressions=expressions, - partition=partition, - format=format, - ) - - def _parse_multitable_inserts(self, comments: t.Optional[t.List[str]]) -> exp.MultitableInserts: - kind = self._prev.text.upper() - expressions = [] - - def parse_conditional_insert() -> t.Optional[exp.ConditionalInsert]: - if self._match(TokenType.WHEN): - expression = self._parse_disjunction() - self._match(TokenType.THEN) - else: - expression = None - - else_ = self._match(TokenType.ELSE) - - if not self._match(TokenType.INTO): - return None - - return self.expression( - exp.ConditionalInsert, - this=self.expression( - exp.Insert, - this=self._parse_table(schema=True), - expression=self._parse_derived_table_values(), - ), - expression=expression, - else_=else_, - ) - - expression = parse_conditional_insert() - while expression is not None: - expressions.append(expression) - expression = parse_conditional_insert() - - return self.expression( - exp.MultitableInserts, - kind=kind, - comments=comments, - expressions=expressions, - source=self._parse_table(), - ) - - def _parse_insert(self) -> t.Union[exp.Insert, exp.MultitableInserts]: - comments = [] - hint = self._parse_hint() - overwrite = self._match(TokenType.OVERWRITE) - ignore = self._match(TokenType.IGNORE) - local = self._match_text_seq("LOCAL") - alternative = None - is_function = None - - if self._match_text_seq("DIRECTORY"): - this: t.Optional[exp.Expression] = self.expression( - exp.Directory, - this=self._parse_var_or_string(), - local=local, - row_format=self._parse_row_format(match_row=True), - ) - else: - if self._match_set((TokenType.FIRST, TokenType.ALL)): - comments += ensure_list(self._prev_comments) - return self._parse_multitable_inserts(comments) - - if self._match(TokenType.OR): - alternative = self._match_texts(self.INSERT_ALTERNATIVES) and self._prev.text - - self._match(TokenType.INTO) - comments += ensure_list(self._prev_comments) - self._match(TokenType.TABLE) - is_function = self._match(TokenType.FUNCTION) - - this = ( - self._parse_table(schema=True, parse_partition=True) - if not is_function - else self._parse_function() - ) - if isinstance(this, exp.Table) and self._match(TokenType.ALIAS, advance=False): - this.set("alias", self._parse_table_alias()) - - returning = self._parse_returning() - - return self.expression( - exp.Insert, - comments=comments, - hint=hint, - is_function=is_function, - this=this, - stored=self._match_text_seq("STORED") and self._parse_stored(), - by_name=self._match_text_seq("BY", "NAME"), - exists=self._parse_exists(), - where=self._match_pair(TokenType.REPLACE, TokenType.WHERE) and self._parse_assignment(), - partition=self._match(TokenType.PARTITION_BY) and self._parse_partitioned_by(), - settings=self._match_text_seq("SETTINGS") and self._parse_settings_property(), - expression=self._parse_derived_table_values() or self._parse_ddl_select(), - conflict=self._parse_on_conflict(), - returning=returning or self._parse_returning(), - overwrite=overwrite, - alternative=alternative, - ignore=ignore, - source=self._match(TokenType.TABLE) and self._parse_table(), - ) - - def _parse_kill(self) -> exp.Kill: - kind = exp.var(self._prev.text) if self._match_texts(("CONNECTION", "QUERY")) else None - - return self.expression( - exp.Kill, - this=self._parse_primary(), - kind=kind, - ) - - def _parse_on_conflict(self) -> t.Optional[exp.OnConflict]: - conflict = self._match_text_seq("ON", "CONFLICT") - duplicate = 
self._match_text_seq("ON", "DUPLICATE", "KEY") - - if not conflict and not duplicate: - return None - - conflict_keys = None - constraint = None - - if conflict: - if self._match_text_seq("ON", "CONSTRAINT"): - constraint = self._parse_id_var() - elif self._match(TokenType.L_PAREN): - conflict_keys = self._parse_csv(self._parse_id_var) - self._match_r_paren() - - action = self._parse_var_from_options(self.CONFLICT_ACTIONS) - if self._prev.token_type == TokenType.UPDATE: - self._match(TokenType.SET) - expressions = self._parse_csv(self._parse_equality) - else: - expressions = None - - return self.expression( - exp.OnConflict, - duplicate=duplicate, - expressions=expressions, - action=action, - conflict_keys=conflict_keys, - constraint=constraint, - where=self._parse_where(), - ) - - def _parse_returning(self) -> t.Optional[exp.Returning]: - if not self._match(TokenType.RETURNING): - return None - return self.expression( - exp.Returning, - expressions=self._parse_csv(self._parse_expression), - into=self._match(TokenType.INTO) and self._parse_table_part(), - ) - - def _parse_row(self) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: - if not self._match(TokenType.FORMAT): - return None - return self._parse_row_format() - - def _parse_serde_properties(self, with_: bool = False) -> t.Optional[exp.SerdeProperties]: - index = self._index - with_ = with_ or self._match_text_seq("WITH") - - if not self._match(TokenType.SERDE_PROPERTIES): - self._retreat(index) - return None - return self.expression( - exp.SerdeProperties, - **{ # type: ignore - "expressions": self._parse_wrapped_properties(), - "with": with_, - }, - ) - - def _parse_row_format( - self, match_row: bool = False - ) -> t.Optional[exp.RowFormatSerdeProperty | exp.RowFormatDelimitedProperty]: - if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): - return None - - if self._match_text_seq("SERDE"): - this = self._parse_string() - - serde_properties = self._parse_serde_properties() - - return self.expression( - exp.RowFormatSerdeProperty, this=this, serde_properties=serde_properties - ) - - self._match_text_seq("DELIMITED") - - kwargs = {} - - if self._match_text_seq("FIELDS", "TERMINATED", "BY"): - kwargs["fields"] = self._parse_string() - if self._match_text_seq("ESCAPED", "BY"): - kwargs["escaped"] = self._parse_string() - if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"): - kwargs["collection_items"] = self._parse_string() - if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"): - kwargs["map_keys"] = self._parse_string() - if self._match_text_seq("LINES", "TERMINATED", "BY"): - kwargs["lines"] = self._parse_string() - if self._match_text_seq("NULL", "DEFINED", "AS"): - kwargs["null"] = self._parse_string() - - return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore - - def _parse_load(self) -> exp.LoadData | exp.Command: - if self._match_text_seq("DATA"): - local = self._match_text_seq("LOCAL") - self._match_text_seq("INPATH") - inpath = self._parse_string() - overwrite = self._match(TokenType.OVERWRITE) - self._match_pair(TokenType.INTO, TokenType.TABLE) - - return self.expression( - exp.LoadData, - this=self._parse_table(schema=True), - local=local, - overwrite=overwrite, - inpath=inpath, - partition=self._parse_partition(), - input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(), - serde=self._match_text_seq("SERDE") and self._parse_string(), - ) - return self._parse_as_command(self._prev) - - def _parse_delete(self) -> 
exp.Delete: - # This handles MySQL's "Multiple-Table Syntax" - # https://dev.mysql.com/doc/refman/8.0/en/delete.html - tables = None - if not self._match(TokenType.FROM, advance=False): - tables = self._parse_csv(self._parse_table) or None - - returning = self._parse_returning() - - return self.expression( - exp.Delete, - tables=tables, - this=self._match(TokenType.FROM) and self._parse_table(joins=True), - using=self._match(TokenType.USING) and self._parse_table(joins=True), - cluster=self._match(TokenType.ON) and self._parse_on_property(), - where=self._parse_where(), - returning=returning or self._parse_returning(), - limit=self._parse_limit(), - ) - - def _parse_update(self) -> exp.Update: - this = self._parse_table(joins=True, alias_tokens=self.UPDATE_ALIAS_TOKENS) - expressions = self._match(TokenType.SET) and self._parse_csv(self._parse_equality) - returning = self._parse_returning() - return self.expression( - exp.Update, - **{ # type: ignore - "this": this, - "expressions": expressions, - "from": self._parse_from(joins=True), - "where": self._parse_where(), - "returning": returning or self._parse_returning(), - "order": self._parse_order(), - "limit": self._parse_limit(), - }, - ) - - def _parse_use(self) -> exp.Use: - return self.expression( - exp.Use, - kind=self._parse_var_from_options(self.USABLES, raise_unmatched=False), - this=self._parse_table(schema=False), - ) - - def _parse_uncache(self) -> exp.Uncache: - if not self._match(TokenType.TABLE): - self.raise_error("Expecting TABLE after UNCACHE") - - return self.expression( - exp.Uncache, exists=self._parse_exists(), this=self._parse_table(schema=True) - ) - - def _parse_cache(self) -> exp.Cache: - lazy = self._match_text_seq("LAZY") - self._match(TokenType.TABLE) - table = self._parse_table(schema=True) - - options = [] - if self._match_text_seq("OPTIONS"): - self._match_l_paren() - k = self._parse_string() - self._match(TokenType.EQ) - v = self._parse_string() - options = [k, v] - self._match_r_paren() - - self._match(TokenType.ALIAS) - return self.expression( - exp.Cache, - this=table, - lazy=lazy, - options=options, - expression=self._parse_select(nested=True), - ) - - def _parse_partition(self) -> t.Optional[exp.Partition]: - if not self._match_texts(self.PARTITION_KEYWORDS): - return None - - return self.expression( - exp.Partition, - subpartition=self._prev.text.upper() == "SUBPARTITION", - expressions=self._parse_wrapped_csv(self._parse_assignment), - ) - - def _parse_value(self, values: bool = True) -> t.Optional[exp.Tuple]: - def _parse_value_expression() -> t.Optional[exp.Expression]: - if self.dialect.SUPPORTS_VALUES_DEFAULT and self._match(TokenType.DEFAULT): - return exp.var(self._prev.text.upper()) - return self._parse_expression() - - if self._match(TokenType.L_PAREN): - expressions = self._parse_csv(_parse_value_expression) - self._match_r_paren() - return self.expression(exp.Tuple, expressions=expressions) - - # In some dialects we can have VALUES 1, 2 which results in 1 column & 2 rows. 
- expression = self._parse_expression() - if expression: - return self.expression(exp.Tuple, expressions=[expression]) - return None - - def _parse_projections(self) -> t.List[exp.Expression]: - return self._parse_expressions() - - def _parse_wrapped_select(self, table: bool = False) -> t.Optional[exp.Expression]: - if self._match_set((TokenType.PIVOT, TokenType.UNPIVOT)): - this: t.Optional[exp.Expression] = self._parse_simplified_pivot( - is_unpivot=self._prev.token_type == TokenType.UNPIVOT - ) - elif self._match(TokenType.FROM): - from_ = self._parse_from(skip_from_token=True) - # Support parentheses for duckdb FROM-first syntax - select = self._parse_select() - if select: - select.set("from", from_) - this = select - else: - this = exp.select("*").from_(t.cast(exp.From, from_)) - else: - this = ( - self._parse_table() - if table - else self._parse_select(nested=True, parse_set_operation=False) - ) - - # Transform exp.Values into a exp.Table to pass through parse_query_modifiers - # in case a modifier (e.g. join) is following - if table and isinstance(this, exp.Values) and this.alias: - alias = this.args["alias"].pop() - this = exp.Table(this=this, alias=alias) - - this = self._parse_query_modifiers(self._parse_set_operations(this)) - - return this - - def _parse_select( - self, - nested: bool = False, - table: bool = False, - parse_subquery_alias: bool = True, - parse_set_operation: bool = True, - ) -> t.Optional[exp.Expression]: - cte = self._parse_with() - - if cte: - this = self._parse_statement() - - if not this: - self.raise_error("Failed to parse any statement following CTE") - return cte - - if "with" in this.arg_types: - this.set("with", cte) - else: - self.raise_error(f"{this.key} does not support CTE") - this = cte - - return this - - # duckdb supports leading with FROM x - from_ = self._parse_from() if self._match(TokenType.FROM, advance=False) else None - - if self._match(TokenType.SELECT): - comments = self._prev_comments - - hint = self._parse_hint() - - if self._next and not self._next.token_type == TokenType.DOT: - all_ = self._match(TokenType.ALL) - distinct = self._match_set(self.DISTINCT_TOKENS) - else: - all_, distinct = None, None - - kind = ( - self._match(TokenType.ALIAS) - and self._match_texts(("STRUCT", "VALUE")) - and self._prev.text.upper() - ) - - if distinct: - distinct = self.expression( - exp.Distinct, - on=self._parse_value(values=False) if self._match(TokenType.ON) else None, - ) - - if all_ and distinct: - self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") - - operation_modifiers = [] - while self._curr and self._match_texts(self.OPERATION_MODIFIERS): - operation_modifiers.append(exp.var(self._prev.text.upper())) - - limit = self._parse_limit(top=True) - projections = self._parse_projections() - - this = self.expression( - exp.Select, - kind=kind, - hint=hint, - distinct=distinct, - expressions=projections, - limit=limit, - operation_modifiers=operation_modifiers or None, - ) - this.comments = comments - - into = self._parse_into() - if into: - this.set("into", into) - - if not from_: - from_ = self._parse_from() - - if from_: - this.set("from", from_) - - this = self._parse_query_modifiers(this) - elif (table or nested) and self._match(TokenType.L_PAREN): - this = self._parse_wrapped_select(table=table) - - # We return early here so that the UNION isn't attached to the subquery by the - # following call to _parse_set_operations, but instead becomes the parent node - self._match_r_paren() - return self._parse_subquery(this, 
parse_alias=parse_subquery_alias) - elif self._match(TokenType.VALUES, advance=False): - this = self._parse_derived_table_values() - elif from_: - this = exp.select("*").from_(from_.this, copy=False) - if self._match(TokenType.PIPE_GT, advance=False): - return self._parse_pipe_syntax_query(this) - elif self._match(TokenType.SUMMARIZE): - table = self._match(TokenType.TABLE) - this = self._parse_select() or self._parse_string() or self._parse_table() - return self.expression(exp.Summarize, this=this, table=table) - elif self._match(TokenType.DESCRIBE): - this = self._parse_describe() - elif self._match_text_seq("STREAM"): - this = self._parse_function() - if this: - this = self.expression(exp.Stream, this=this) - else: - self._retreat(self._index - 1) - else: - this = None - - return self._parse_set_operations(this) if parse_set_operation else this - - def _parse_recursive_with_search(self) -> t.Optional[exp.RecursiveWithSearch]: - self._match_text_seq("SEARCH") - - kind = self._match_texts(self.RECURSIVE_CTE_SEARCH_KIND) and self._prev.text.upper() - - if not kind: - return None - - self._match_text_seq("FIRST", "BY") - - return self.expression( - exp.RecursiveWithSearch, - kind=kind, - this=self._parse_id_var(), - expression=self._match_text_seq("SET") and self._parse_id_var(), - using=self._match_text_seq("USING") and self._parse_id_var(), - ) - - def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.With]: - if not skip_with_token and not self._match(TokenType.WITH): - return None - - comments = self._prev_comments - recursive = self._match(TokenType.RECURSIVE) - - last_comments = None - expressions = [] - while True: - cte = self._parse_cte() - if isinstance(cte, exp.CTE): - expressions.append(cte) - if last_comments: - cte.add_comments(last_comments) - - if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH): - break - else: - self._match(TokenType.WITH) - - last_comments = self._prev_comments - - return self.expression( - exp.With, - comments=comments, - expressions=expressions, - recursive=recursive, - search=self._parse_recursive_with_search(), - ) - - def _parse_cte(self) -> t.Optional[exp.CTE]: - index = self._index - - alias = self._parse_table_alias(self.ID_VAR_TOKENS) - if not alias or not alias.this: - self.raise_error("Expected CTE to have alias") - - if not self._match(TokenType.ALIAS) and not self.OPTIONAL_ALIAS_TOKEN_CTE: - self._retreat(index) - return None - - comments = self._prev_comments - - if self._match_text_seq("NOT", "MATERIALIZED"): - materialized = False - elif self._match_text_seq("MATERIALIZED"): - materialized = True - else: - materialized = None - - cte = self.expression( - exp.CTE, - this=self._parse_wrapped(self._parse_statement), - alias=alias, - materialized=materialized, - comments=comments, - ) - - if isinstance(cte.this, exp.Values): - cte.set("this", exp.select("*").from_(exp.alias_(cte.this, "_values", table=True))) - - return cte - - def _parse_table_alias( - self, alias_tokens: t.Optional[t.Collection[TokenType]] = None - ) -> t.Optional[exp.TableAlias]: - # In some dialects, LIMIT and OFFSET can act as both identifiers and keywords (clauses) - # so this section tries to parse the clause version and if it fails, it treats the token - # as an identifier (alias) - if self._can_parse_limit_or_offset(): - return None - - any_token = self._match(TokenType.ALIAS) - alias = ( - self._parse_id_var(any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) - or self._parse_string_as_identifier() - ) - - index = 
self._index - if self._match(TokenType.L_PAREN): - columns = self._parse_csv(self._parse_function_parameter) - self._match_r_paren() if columns else self._retreat(index) - else: - columns = None - - if not alias and not columns: - return None - - table_alias = self.expression(exp.TableAlias, this=alias, columns=columns) - - # We bubble up comments from the Identifier to the TableAlias - if isinstance(alias, exp.Identifier): - table_alias.add_comments(alias.pop_comments()) - - return table_alias - - def _parse_subquery( - self, this: t.Optional[exp.Expression], parse_alias: bool = True - ) -> t.Optional[exp.Subquery]: - if not this: - return None - - return self.expression( - exp.Subquery, - this=this, - pivots=self._parse_pivots(), - alias=self._parse_table_alias() if parse_alias else None, - sample=self._parse_table_sample(), - ) - - def _implicit_unnests_to_explicit(self, this: E) -> E: - from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm - - refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name} - for i, join in enumerate(this.args.get("joins") or []): - table = join.this - normalized_table = table.copy() - normalized_table.meta["maybe_column"] = True - normalized_table = _norm(normalized_table, dialect=self.dialect) - - if isinstance(table, exp.Table) and not join.args.get("on"): - if normalized_table.parts[0].name in refs: - table_as_column = table.to_column() - unnest = exp.Unnest(expressions=[table_as_column]) - - # Table.to_column creates a parent Alias node that we want to convert to - # a TableAlias and attach to the Unnest, so it matches the parser's output - if isinstance(table.args.get("alias"), exp.TableAlias): - table_as_column.replace(table_as_column.this) - exp.alias_(unnest, None, table=[table.args["alias"].this], copy=False) - - table.replace(unnest) - - refs.add(normalized_table.alias_or_name) - - return this - - def _parse_query_modifiers( - self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - if isinstance(this, self.MODIFIABLES): - for join in self._parse_joins(): - this.append("joins", join) - for lateral in iter(self._parse_lateral, None): - this.append("laterals", lateral) - - while True: - if self._match_set(self.QUERY_MODIFIER_PARSERS, advance=False): - parser = self.QUERY_MODIFIER_PARSERS[self._curr.token_type] - key, expression = parser(self) - - if expression: - this.set(key, expression) - if key == "limit": - offset = expression.args.pop("offset", None) - - if offset: - offset = exp.Offset(expression=offset) - this.set("offset", offset) - - limit_by_expressions = expression.expressions - expression.set("expressions", None) - offset.set("expressions", limit_by_expressions) - continue - break - - if self.SUPPORTS_IMPLICIT_UNNEST and this and this.args.get("from"): - this = self._implicit_unnests_to_explicit(this) - - return this - - def _parse_hint_fallback_to_string(self) -> t.Optional[exp.Hint]: - start = self._curr - while self._curr: - self._advance() - - end = self._tokens[self._index - 1] - return exp.Hint(expressions=[self._find_sql(start, end)]) - - def _parse_hint_function_call(self) -> t.Optional[exp.Expression]: - return self._parse_function_call() - - def _parse_hint_body(self) -> t.Optional[exp.Hint]: - start_index = self._index - should_fallback_to_string = False - - hints = [] - try: - for hint in iter( - lambda: self._parse_csv( - lambda: self._parse_hint_function_call() or self._parse_var(upper=True), - ), - [], - ): - hints.extend(hint) - except ParseError: - 
should_fallback_to_string = True - - if should_fallback_to_string or self._curr: - self._retreat(start_index) - return self._parse_hint_fallback_to_string() - - return self.expression(exp.Hint, expressions=hints) - - def _parse_hint(self) -> t.Optional[exp.Hint]: - if self._match(TokenType.HINT) and self._prev_comments: - return exp.maybe_parse(self._prev_comments[0], into=exp.Hint, dialect=self.dialect) - - return None - - def _parse_into(self) -> t.Optional[exp.Into]: - if not self._match(TokenType.INTO): - return None - - temp = self._match(TokenType.TEMPORARY) - unlogged = self._match_text_seq("UNLOGGED") - self._match(TokenType.TABLE) - - return self.expression( - exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged - ) - - def _parse_from( - self, joins: bool = False, skip_from_token: bool = False - ) -> t.Optional[exp.From]: - if not skip_from_token and not self._match(TokenType.FROM): - return None - - return self.expression( - exp.From, comments=self._prev_comments, this=self._parse_table(joins=joins) - ) - - def _parse_match_recognize_measure(self) -> exp.MatchRecognizeMeasure: - return self.expression( - exp.MatchRecognizeMeasure, - window_frame=self._match_texts(("FINAL", "RUNNING")) and self._prev.text.upper(), - this=self._parse_expression(), - ) - - def _parse_match_recognize(self) -> t.Optional[exp.MatchRecognize]: - if not self._match(TokenType.MATCH_RECOGNIZE): - return None - - self._match_l_paren() - - partition = self._parse_partition_by() - order = self._parse_order() - - measures = ( - self._parse_csv(self._parse_match_recognize_measure) - if self._match_text_seq("MEASURES") - else None - ) - - if self._match_text_seq("ONE", "ROW", "PER", "MATCH"): - rows = exp.var("ONE ROW PER MATCH") - elif self._match_text_seq("ALL", "ROWS", "PER", "MATCH"): - text = "ALL ROWS PER MATCH" - if self._match_text_seq("SHOW", "EMPTY", "MATCHES"): - text += " SHOW EMPTY MATCHES" - elif self._match_text_seq("OMIT", "EMPTY", "MATCHES"): - text += " OMIT EMPTY MATCHES" - elif self._match_text_seq("WITH", "UNMATCHED", "ROWS"): - text += " WITH UNMATCHED ROWS" - rows = exp.var(text) - else: - rows = None - - if self._match_text_seq("AFTER", "MATCH", "SKIP"): - text = "AFTER MATCH SKIP" - if self._match_text_seq("PAST", "LAST", "ROW"): - text += " PAST LAST ROW" - elif self._match_text_seq("TO", "NEXT", "ROW"): - text += " TO NEXT ROW" - elif self._match_text_seq("TO", "FIRST"): - text += f" TO FIRST {self._advance_any().text}" # type: ignore - elif self._match_text_seq("TO", "LAST"): - text += f" TO LAST {self._advance_any().text}" # type: ignore - after = exp.var(text) - else: - after = None - - if self._match_text_seq("PATTERN"): - self._match_l_paren() - - if not self._curr: - self.raise_error("Expecting )", self._curr) - - paren = 1 - start = self._curr - - while self._curr and paren > 0: - if self._curr.token_type == TokenType.L_PAREN: - paren += 1 - if self._curr.token_type == TokenType.R_PAREN: - paren -= 1 - - end = self._prev - self._advance() - - if paren > 0: - self.raise_error("Expecting )", self._curr) - - pattern = exp.var(self._find_sql(start, end)) - else: - pattern = None - - define = ( - self._parse_csv(self._parse_name_as_expression) - if self._match_text_seq("DEFINE") - else None - ) - - self._match_r_paren() - - return self.expression( - exp.MatchRecognize, - partition_by=partition, - order=order, - measures=measures, - rows=rows, - after=after, - pattern=pattern, - define=define, - alias=self._parse_table_alias(), - ) - - def 
_parse_lateral(self) -> t.Optional[exp.Lateral]: - cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) - if not cross_apply and self._match_pair(TokenType.OUTER, TokenType.APPLY): - cross_apply = False - - if cross_apply is not None: - this = self._parse_select(table=True) - view = None - outer = None - elif self._match(TokenType.LATERAL): - this = self._parse_select(table=True) - view = self._match(TokenType.VIEW) - outer = self._match(TokenType.OUTER) - else: - return None - - if not this: - this = ( - self._parse_unnest() - or self._parse_function() - or self._parse_id_var(any_token=False) - ) - - while self._match(TokenType.DOT): - this = exp.Dot( - this=this, - expression=self._parse_function() or self._parse_id_var(any_token=False), - ) - - ordinality: t.Optional[bool] = None - - if view: - table = self._parse_id_var(any_token=False) - columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] - table_alias: t.Optional[exp.TableAlias] = self.expression( - exp.TableAlias, this=table, columns=columns - ) - elif isinstance(this, (exp.Subquery, exp.Unnest)) and this.alias: - # We move the alias from the lateral's child node to the lateral itself - table_alias = this.args["alias"].pop() - else: - ordinality = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) - table_alias = self._parse_table_alias() - - return self.expression( - exp.Lateral, - this=this, - view=view, - outer=outer, - alias=table_alias, - cross_apply=cross_apply, - ordinality=ordinality, - ) - - def _parse_join_parts( - self, - ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: - return ( - self._match_set(self.JOIN_METHODS) and self._prev, - self._match_set(self.JOIN_SIDES) and self._prev, - self._match_set(self.JOIN_KINDS) and self._prev, - ) - - def _parse_using_identifiers(self) -> t.List[exp.Expression]: - def _parse_column_as_identifier() -> t.Optional[exp.Expression]: - this = self._parse_column() - if isinstance(this, exp.Column): - return this.this - return this - - return self._parse_wrapped_csv(_parse_column_as_identifier, optional=True) - - def _parse_join( - self, skip_join_token: bool = False, parse_bracket: bool = False - ) -> t.Optional[exp.Join]: - if self._match(TokenType.COMMA): - table = self._try_parse(self._parse_table) - if table: - return self.expression(exp.Join, this=table) - return None - - index = self._index - method, side, kind = self._parse_join_parts() - hint = self._prev.text if self._match_texts(self.JOIN_HINTS) else None - join = self._match(TokenType.JOIN) or (kind and kind.token_type == TokenType.STRAIGHT_JOIN) - - if not skip_join_token and not join: - self._retreat(index) - kind = None - method = None - side = None - - outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY, False) - cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY, False) - - if not skip_join_token and not join and not outer_apply and not cross_apply: - return None - - kwargs: t.Dict[str, t.Any] = {"this": self._parse_table(parse_bracket=parse_bracket)} - if kind and kind.token_type == TokenType.ARRAY and self._match(TokenType.COMMA): - kwargs["expressions"] = self._parse_csv( - lambda: self._parse_table(parse_bracket=parse_bracket) - ) - - if method: - kwargs["method"] = method.text - if side: - kwargs["side"] = side.text - if kind: - kwargs["kind"] = kind.text - if hint: - kwargs["hint"] = hint - - if self._match(TokenType.MATCH_CONDITION): - kwargs["match_condition"] = self._parse_wrapped(self._parse_comparison) - - if 
self._match(TokenType.ON): - kwargs["on"] = self._parse_assignment() - elif self._match(TokenType.USING): - kwargs["using"] = self._parse_using_identifiers() - elif ( - not (outer_apply or cross_apply) - and not isinstance(kwargs["this"], exp.Unnest) - and not (kind and kind.token_type in (TokenType.CROSS, TokenType.ARRAY)) - ): - index = self._index - joins: t.Optional[list] = list(self._parse_joins()) - - if joins and self._match(TokenType.ON): - kwargs["on"] = self._parse_assignment() - elif joins and self._match(TokenType.USING): - kwargs["using"] = self._parse_using_identifiers() - else: - joins = None - self._retreat(index) - - kwargs["this"].set("joins", joins if joins else None) - - kwargs["pivots"] = self._parse_pivots() - - comments = [c for token in (method, side, kind) if token for c in token.comments] - return self.expression(exp.Join, comments=comments, **kwargs) - - def _parse_opclass(self) -> t.Optional[exp.Expression]: - this = self._parse_assignment() - - if self._match_texts(self.OPCLASS_FOLLOW_KEYWORDS, advance=False): - return this - - if not self._match_set(self.OPTYPE_FOLLOW_TOKENS, advance=False): - return self.expression(exp.Opclass, this=this, expression=self._parse_table_parts()) - - return this - - def _parse_index_params(self) -> exp.IndexParameters: - using = self._parse_var(any_token=True) if self._match(TokenType.USING) else None - - if self._match(TokenType.L_PAREN, advance=False): - columns = self._parse_wrapped_csv(self._parse_with_operator) - else: - columns = None - - include = self._parse_wrapped_id_vars() if self._match_text_seq("INCLUDE") else None - partition_by = self._parse_partition_by() - with_storage = self._match(TokenType.WITH) and self._parse_wrapped_properties() - tablespace = ( - self._parse_var(any_token=True) - if self._match_text_seq("USING", "INDEX", "TABLESPACE") - else None - ) - where = self._parse_where() - - on = self._parse_field() if self._match(TokenType.ON) else None - - return self.expression( - exp.IndexParameters, - using=using, - columns=columns, - include=include, - partition_by=partition_by, - where=where, - with_storage=with_storage, - tablespace=tablespace, - on=on, - ) - - def _parse_index( - self, index: t.Optional[exp.Expression] = None, anonymous: bool = False - ) -> t.Optional[exp.Index]: - if index or anonymous: - unique = None - primary = None - amp = None - - self._match(TokenType.ON) - self._match(TokenType.TABLE) # hive - table = self._parse_table_parts(schema=True) - else: - unique = self._match(TokenType.UNIQUE) - primary = self._match_text_seq("PRIMARY") - amp = self._match_text_seq("AMP") - - if not self._match(TokenType.INDEX): - return None - - index = self._parse_id_var() - table = None - - params = self._parse_index_params() - - return self.expression( - exp.Index, - this=index, - table=table, - unique=unique, - primary=primary, - amp=amp, - params=params, - ) - - def _parse_table_hints(self) -> t.Optional[t.List[exp.Expression]]: - hints: t.List[exp.Expression] = [] - if self._match_pair(TokenType.WITH, TokenType.L_PAREN): - # https://learn.microsoft.com/en-us/sql/t-sql/queries/hints-transact-sql-table?view=sql-server-ver16 - hints.append( - self.expression( - exp.WithTableHint, - expressions=self._parse_csv( - lambda: self._parse_function() or self._parse_var(any_token=True) - ), - ) - ) - self._match_r_paren() - else: - # https://dev.mysql.com/doc/refman/8.0/en/index-hints.html - while self._match_set(self.TABLE_INDEX_HINT_TOKENS): - hint = exp.IndexTableHint(this=self._prev.text.upper()) - - 
self._match_set((TokenType.INDEX, TokenType.KEY)) - if self._match(TokenType.FOR): - hint.set("target", self._advance_any() and self._prev.text.upper()) - - hint.set("expressions", self._parse_wrapped_id_vars()) - hints.append(hint) - - return hints or None - - def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]: - return ( - (not schema and self._parse_function(optional_parens=False)) - or self._parse_id_var(any_token=False) - or self._parse_string_as_identifier() - or self._parse_placeholder() - ) - - def _parse_table_parts( - self, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False - ) -> exp.Table: - catalog = None - db = None - table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema) - - while self._match(TokenType.DOT): - if catalog: - # This allows nesting the table in arbitrarily many dot expressions if needed - table = self.expression( - exp.Dot, this=table, expression=self._parse_table_part(schema=schema) - ) - else: - catalog = db - db = table - # "" used for tsql FROM a..b case - table = self._parse_table_part(schema=schema) or "" - - if ( - wildcard - and self._is_connected() - and (isinstance(table, exp.Identifier) or not table) - and self._match(TokenType.STAR) - ): - if isinstance(table, exp.Identifier): - table.args["this"] += "*" - else: - table = exp.Identifier(this="*") - - # We bubble up comments from the Identifier to the Table - comments = table.pop_comments() if isinstance(table, exp.Expression) else None - - if is_db_reference: - catalog = db - db = table - table = None - - if not table and not is_db_reference: - self.raise_error(f"Expected table name but got {self._curr}") - if not db and is_db_reference: - self.raise_error(f"Expected database name but got {self._curr}") - - table = self.expression( - exp.Table, - comments=comments, - this=table, - db=db, - catalog=catalog, - ) - - changes = self._parse_changes() - if changes: - table.set("changes", changes) - - at_before = self._parse_historical_data() - if at_before: - table.set("when", at_before) - - pivots = self._parse_pivots() - if pivots: - table.set("pivots", pivots) - - return table - - def _parse_table( - self, - schema: bool = False, - joins: bool = False, - alias_tokens: t.Optional[t.Collection[TokenType]] = None, - parse_bracket: bool = False, - is_db_reference: bool = False, - parse_partition: bool = False, - ) -> t.Optional[exp.Expression]: - lateral = self._parse_lateral() - if lateral: - return lateral - - unnest = self._parse_unnest() - if unnest: - return unnest - - values = self._parse_derived_table_values() - if values: - return values - - subquery = self._parse_select(table=True) - if subquery: - if not subquery.args.get("pivots"): - subquery.set("pivots", self._parse_pivots()) - return subquery - - bracket = parse_bracket and self._parse_bracket(None) - bracket = self.expression(exp.Table, this=bracket) if bracket else None - - rows_from = self._match_text_seq("ROWS", "FROM") and self._parse_wrapped_csv( - self._parse_table - ) - rows_from = self.expression(exp.Table, rows_from=rows_from) if rows_from else None - - only = self._match(TokenType.ONLY) - - this = t.cast( - exp.Expression, - bracket - or rows_from - or self._parse_bracket( - self._parse_table_parts(schema=schema, is_db_reference=is_db_reference) - ), - ) - - if only: - this.set("only", only) - - # Postgres supports a wildcard (table) suffix operator, which is a no-op in this context - self._match_text_seq("*") - - parse_partition = parse_partition or 
self.SUPPORTS_PARTITION_SELECTION - if parse_partition and self._match(TokenType.PARTITION, advance=False): - this.set("partition", self._parse_partition()) - - if schema: - return self._parse_schema(this=this) - - version = self._parse_version() - - if version: - this.set("version", version) - - if self.dialect.ALIAS_POST_TABLESAMPLE: - this.set("sample", self._parse_table_sample()) - - alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS) - if alias: - this.set("alias", alias) - - if isinstance(this, exp.Table) and self._match_text_seq("AT"): - return self.expression( - exp.AtIndex, this=this.to_column(copy=False), expression=self._parse_id_var() - ) - - this.set("hints", self._parse_table_hints()) - - if not this.args.get("pivots"): - this.set("pivots", self._parse_pivots()) - - if not self.dialect.ALIAS_POST_TABLESAMPLE: - this.set("sample", self._parse_table_sample()) - - if joins: - for join in self._parse_joins(): - this.append("joins", join) - - if self._match_pair(TokenType.WITH, TokenType.ORDINALITY): - this.set("ordinality", True) - this.set("alias", self._parse_table_alias()) - - return this - - def _parse_version(self) -> t.Optional[exp.Version]: - if self._match(TokenType.TIMESTAMP_SNAPSHOT): - this = "TIMESTAMP" - elif self._match(TokenType.VERSION_SNAPSHOT): - this = "VERSION" - else: - return None - - if self._match_set((TokenType.FROM, TokenType.BETWEEN)): - kind = self._prev.text.upper() - start = self._parse_bitwise() - self._match_texts(("TO", "AND")) - end = self._parse_bitwise() - expression: t.Optional[exp.Expression] = self.expression( - exp.Tuple, expressions=[start, end] - ) - elif self._match_text_seq("CONTAINED", "IN"): - kind = "CONTAINED IN" - expression = self.expression( - exp.Tuple, expressions=self._parse_wrapped_csv(self._parse_bitwise) - ) - elif self._match(TokenType.ALL): - kind = "ALL" - expression = None - else: - self._match_text_seq("AS", "OF") - kind = "AS OF" - expression = self._parse_type() - - return self.expression(exp.Version, this=this, expression=expression, kind=kind) - - def _parse_historical_data(self) -> t.Optional[exp.HistoricalData]: - # https://docs.snowflake.com/en/sql-reference/constructs/at-before - index = self._index - historical_data = None - if self._match_texts(self.HISTORICAL_DATA_PREFIX): - this = self._prev.text.upper() - kind = ( - self._match(TokenType.L_PAREN) - and self._match_texts(self.HISTORICAL_DATA_KIND) - and self._prev.text.upper() - ) - expression = self._match(TokenType.FARROW) and self._parse_bitwise() - - if expression: - self._match_r_paren() - historical_data = self.expression( - exp.HistoricalData, this=this, kind=kind, expression=expression - ) - else: - self._retreat(index) - - return historical_data - - def _parse_changes(self) -> t.Optional[exp.Changes]: - if not self._match_text_seq("CHANGES", "(", "INFORMATION", "=>"): - return None - - information = self._parse_var(any_token=True) - self._match_r_paren() - - return self.expression( - exp.Changes, - information=information, - at_before=self._parse_historical_data(), - end=self._parse_historical_data(), - ) - - def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: - if not self._match(TokenType.UNNEST): - return None - - expressions = self._parse_wrapped_csv(self._parse_equality) - offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY) - - alias = self._parse_table_alias() if with_alias else None - - if alias: - if self.dialect.UNNEST_COLUMN_ONLY: - if alias.args.get("columns"): - 
self.raise_error("Unexpected extra column alias in unnest.") - - alias.set("columns", [alias.this]) - alias.set("this", None) - - columns = alias.args.get("columns") or [] - if offset and len(expressions) < len(columns): - offset = columns.pop() - - if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET): - self._match(TokenType.ALIAS) - offset = self._parse_id_var( - any_token=False, tokens=self.UNNEST_OFFSET_ALIAS_TOKENS - ) or exp.to_identifier("offset") - - return self.expression(exp.Unnest, expressions=expressions, alias=alias, offset=offset) - - def _parse_derived_table_values(self) -> t.Optional[exp.Values]: - is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) - if not is_derived and not ( - # ClickHouse's `FORMAT Values` is equivalent to `VALUES` - self._match_text_seq("VALUES") or self._match_text_seq("FORMAT", "VALUES") - ): - return None - - expressions = self._parse_csv(self._parse_value) - alias = self._parse_table_alias() - - if is_derived: - self._match_r_paren() - - return self.expression( - exp.Values, expressions=expressions, alias=alias or self._parse_table_alias() - ) - - def _parse_table_sample(self, as_modifier: bool = False) -> t.Optional[exp.TableSample]: - if not self._match(TokenType.TABLE_SAMPLE) and not ( - as_modifier and self._match_text_seq("USING", "SAMPLE") - ): - return None - - bucket_numerator = None - bucket_denominator = None - bucket_field = None - percent = None - size = None - seed = None - - method = self._parse_var(tokens=(TokenType.ROW,), upper=True) - matched_l_paren = self._match(TokenType.L_PAREN) - - if self.TABLESAMPLE_CSV: - num = None - expressions = self._parse_csv(self._parse_primary) - else: - expressions = None - num = ( - self._parse_factor() - if self._match(TokenType.NUMBER, advance=False) - else self._parse_primary() or self._parse_placeholder() - ) - - if self._match_text_seq("BUCKET"): - bucket_numerator = self._parse_number() - self._match_text_seq("OUT", "OF") - bucket_denominator = bucket_denominator = self._parse_number() - self._match(TokenType.ON) - bucket_field = self._parse_field() - elif self._match_set((TokenType.PERCENT, TokenType.MOD)): - percent = num - elif self._match(TokenType.ROWS) or not self.dialect.TABLESAMPLE_SIZE_IS_PERCENT: - size = num - else: - percent = num - - if matched_l_paren: - self._match_r_paren() - - if self._match(TokenType.L_PAREN): - method = self._parse_var(upper=True) - seed = self._match(TokenType.COMMA) and self._parse_number() - self._match_r_paren() - elif self._match_texts(("SEED", "REPEATABLE")): - seed = self._parse_wrapped(self._parse_number) - - if not method and self.DEFAULT_SAMPLING_METHOD: - method = exp.var(self.DEFAULT_SAMPLING_METHOD) - - return self.expression( - exp.TableSample, - expressions=expressions, - method=method, - bucket_numerator=bucket_numerator, - bucket_denominator=bucket_denominator, - bucket_field=bucket_field, - percent=percent, - size=size, - seed=seed, - ) - - def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]: - return list(iter(self._parse_pivot, None)) or None - - def _parse_joins(self) -> t.Iterator[exp.Join]: - return iter(self._parse_join, None) - - def _parse_unpivot_columns(self) -> t.Optional[exp.UnpivotColumns]: - if not self._match(TokenType.INTO): - return None - - return self.expression( - exp.UnpivotColumns, - this=self._match_text_seq("NAME") and self._parse_column(), - expressions=self._match_text_seq("VALUE") and self._parse_csv(self._parse_column), - ) - - # https://duckdb.org/docs/sql/statements/pivot - 
def _parse_simplified_pivot(self, is_unpivot: t.Optional[bool] = None) -> exp.Pivot: - def _parse_on() -> t.Optional[exp.Expression]: - this = self._parse_bitwise() - - if self._match(TokenType.IN): - # PIVOT ... ON col IN (row_val1, row_val2) - return self._parse_in(this) - if self._match(TokenType.ALIAS, advance=False): - # UNPIVOT ... ON (col1, col2, col3) AS row_val - return self._parse_alias(this) - - return this - - this = self._parse_table() - expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on) - into = self._parse_unpivot_columns() - using = self._match(TokenType.USING) and self._parse_csv( - lambda: self._parse_alias(self._parse_function()) - ) - group = self._parse_group() - - return self.expression( - exp.Pivot, - this=this, - expressions=expressions, - using=using, - group=group, - unpivot=is_unpivot, - into=into, - ) - - def _parse_pivot_in(self) -> exp.In: - def _parse_aliased_expression() -> t.Optional[exp.Expression]: - this = self._parse_select_or_expression() - - self._match(TokenType.ALIAS) - alias = self._parse_bitwise() - if alias: - if isinstance(alias, exp.Column) and not alias.db: - alias = alias.this - return self.expression(exp.PivotAlias, this=this, alias=alias) - - return this - - value = self._parse_column() - - if not self._match_pair(TokenType.IN, TokenType.L_PAREN): - self.raise_error("Expecting IN (") - - if self._match(TokenType.ANY): - exprs: t.List[exp.Expression] = ensure_list(exp.PivotAny(this=self._parse_order())) - else: - exprs = self._parse_csv(_parse_aliased_expression) - - self._match_r_paren() - return self.expression(exp.In, this=value, expressions=exprs) - - def _parse_pivot(self) -> t.Optional[exp.Pivot]: - index = self._index - include_nulls = None - - if self._match(TokenType.PIVOT): - unpivot = False - elif self._match(TokenType.UNPIVOT): - unpivot = True - - # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax - if self._match_text_seq("INCLUDE", "NULLS"): - include_nulls = True - elif self._match_text_seq("EXCLUDE", "NULLS"): - include_nulls = False - else: - return None - - expressions = [] - - if not self._match(TokenType.L_PAREN): - self._retreat(index) - return None - - if unpivot: - expressions = self._parse_csv(self._parse_column) - else: - expressions = self._parse_csv(lambda: self._parse_alias(self._parse_function())) - - if not expressions: - self.raise_error("Failed to parse PIVOT's aggregation list") - - if not self._match(TokenType.FOR): - self.raise_error("Expecting FOR") - - fields = [] - while True: - field = self._try_parse(self._parse_pivot_in) - if not field: - break - fields.append(field) - - default_on_null = self._match_text_seq("DEFAULT", "ON", "NULL") and self._parse_wrapped( - self._parse_bitwise - ) - - group = self._parse_group() - - self._match_r_paren() - - pivot = self.expression( - exp.Pivot, - expressions=expressions, - fields=fields, - unpivot=unpivot, - include_nulls=include_nulls, - default_on_null=default_on_null, - group=group, - ) - - if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False): - pivot.set("alias", self._parse_table_alias()) - - if not unpivot: - names = self._pivot_column_names(t.cast(t.List[exp.Expression], expressions)) - - columns: t.List[exp.Expression] = [] - all_fields = [] - for pivot_field in pivot.fields: - pivot_field_expressions = pivot_field.expressions - - # The `PivotAny` expression corresponds to `ANY ORDER BY `; we can't infer in this case. 
- if isinstance(seq_get(pivot_field_expressions, 0), exp.PivotAny): - continue - - all_fields.append( - [ - fld.sql() if self.IDENTIFY_PIVOT_STRINGS else fld.alias_or_name - for fld in pivot_field_expressions - ] - ) - - if all_fields: - if names: - all_fields.append(names) - - # Generate all possible combinations of the pivot columns - # e.g PIVOT(sum(...) as total FOR year IN (2000, 2010) FOR country IN ('NL', 'US')) - # generates the product between [[2000, 2010], ['NL', 'US'], ['total']] - for fld_parts_tuple in itertools.product(*all_fields): - fld_parts = list(fld_parts_tuple) - - if names and self.PREFIXED_PIVOT_COLUMNS: - # Move the "name" to the front of the list - fld_parts.insert(0, fld_parts.pop(-1)) - - columns.append(exp.to_identifier("_".join(fld_parts))) - - pivot.set("columns", columns) - - return pivot - - def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: - return [agg.alias for agg in aggregations if agg.alias] - - def _parse_prewhere(self, skip_where_token: bool = False) -> t.Optional[exp.PreWhere]: - if not skip_where_token and not self._match(TokenType.PREWHERE): - return None - - return self.expression( - exp.PreWhere, comments=self._prev_comments, this=self._parse_assignment() - ) - - def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Where]: - if not skip_where_token and not self._match(TokenType.WHERE): - return None - - return self.expression( - exp.Where, comments=self._prev_comments, this=self._parse_assignment() - ) - - def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Group]: - if not skip_group_by_token and not self._match(TokenType.GROUP_BY): - return None - - elements: t.Dict[str, t.Any] = defaultdict(list) - - if self._match(TokenType.ALL): - elements["all"] = True - elif self._match(TokenType.DISTINCT): - elements["all"] = False - - while True: - index = self._index - - elements["expressions"].extend( - self._parse_csv( - lambda: None - if self._match_set((TokenType.CUBE, TokenType.ROLLUP), advance=False) - else self._parse_assignment() - ) - ) - - before_with_index = self._index - with_prefix = self._match(TokenType.WITH) - - if self._match(TokenType.ROLLUP): - elements["rollup"].append( - self._parse_cube_or_rollup(exp.Rollup, with_prefix=with_prefix) - ) - elif self._match(TokenType.CUBE): - elements["cube"].append( - self._parse_cube_or_rollup(exp.Cube, with_prefix=with_prefix) - ) - elif self._match(TokenType.GROUPING_SETS): - elements["grouping_sets"].append( - self.expression( - exp.GroupingSets, - expressions=self._parse_wrapped_csv(self._parse_grouping_set), - ) - ) - elif self._match_text_seq("TOTALS"): - elements["totals"] = True # type: ignore - - if before_with_index <= self._index <= before_with_index + 1: - self._retreat(before_with_index) - break - - if index == self._index: - break - - return self.expression(exp.Group, **elements) # type: ignore - - def _parse_cube_or_rollup(self, kind: t.Type[E], with_prefix: bool = False) -> E: - return self.expression( - kind, expressions=[] if with_prefix else self._parse_wrapped_csv(self._parse_column) - ) - - def _parse_grouping_set(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.L_PAREN): - grouping_set = self._parse_csv(self._parse_column) - self._match_r_paren() - return self.expression(exp.Tuple, expressions=grouping_set) - - return self._parse_column() - - def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Having]: - if not skip_having_token and not 
self._match(TokenType.HAVING): - return None - return self.expression(exp.Having, this=self._parse_assignment()) - - def _parse_qualify(self) -> t.Optional[exp.Qualify]: - if not self._match(TokenType.QUALIFY): - return None - return self.expression(exp.Qualify, this=self._parse_assignment()) - - def _parse_connect_with_prior(self) -> t.Optional[exp.Expression]: - self.NO_PAREN_FUNCTION_PARSERS["PRIOR"] = lambda self: self.expression( - exp.Prior, this=self._parse_bitwise() - ) - connect = self._parse_assignment() - self.NO_PAREN_FUNCTION_PARSERS.pop("PRIOR") - return connect - - def _parse_connect(self, skip_start_token: bool = False) -> t.Optional[exp.Connect]: - if skip_start_token: - start = None - elif self._match(TokenType.START_WITH): - start = self._parse_assignment() - else: - return None - - self._match(TokenType.CONNECT_BY) - nocycle = self._match_text_seq("NOCYCLE") - connect = self._parse_connect_with_prior() - - if not start and self._match(TokenType.START_WITH): - start = self._parse_assignment() - - return self.expression(exp.Connect, start=start, connect=connect, nocycle=nocycle) - - def _parse_name_as_expression(self) -> t.Optional[exp.Expression]: - this = self._parse_id_var(any_token=True) - if self._match(TokenType.ALIAS): - this = self.expression(exp.Alias, alias=this, this=self._parse_assignment()) - return this - - def _parse_interpolate(self) -> t.Optional[t.List[exp.Expression]]: - if self._match_text_seq("INTERPOLATE"): - return self._parse_wrapped_csv(self._parse_name_as_expression) - return None - - def _parse_order( - self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False - ) -> t.Optional[exp.Expression]: - siblings = None - if not skip_order_token and not self._match(TokenType.ORDER_BY): - if not self._match(TokenType.ORDER_SIBLINGS_BY): - return this - - siblings = True - - return self.expression( - exp.Order, - this=this, - expressions=self._parse_csv(self._parse_ordered), - siblings=siblings, - ) - - def _parse_sort(self, exp_class: t.Type[E], token: TokenType) -> t.Optional[E]: - if not self._match(token): - return None - return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) - - def _parse_ordered( - self, parse_method: t.Optional[t.Callable] = None - ) -> t.Optional[exp.Ordered]: - this = parse_method() if parse_method else self._parse_assignment() - if not this: - return None - - if this.name.upper() == "ALL" and self.dialect.SUPPORTS_ORDER_BY_ALL: - this = exp.var("ALL") - - asc = self._match(TokenType.ASC) - desc = self._match(TokenType.DESC) or (asc and False) - - is_nulls_first = self._match_text_seq("NULLS", "FIRST") - is_nulls_last = self._match_text_seq("NULLS", "LAST") - - nulls_first = is_nulls_first or False - explicitly_null_ordered = is_nulls_first or is_nulls_last - - if ( - not explicitly_null_ordered - and ( - (not desc and self.dialect.NULL_ORDERING == "nulls_are_small") - or (desc and self.dialect.NULL_ORDERING != "nulls_are_small") - ) - and self.dialect.NULL_ORDERING != "nulls_are_last" - ): - nulls_first = True - - if self._match_text_seq("WITH", "FILL"): - with_fill = self.expression( - exp.WithFill, - **{ # type: ignore - "from": self._match(TokenType.FROM) and self._parse_bitwise(), - "to": self._match_text_seq("TO") and self._parse_bitwise(), - "step": self._match_text_seq("STEP") and self._parse_bitwise(), - "interpolate": self._parse_interpolate(), - }, - ) - else: - with_fill = None - - return self.expression( - exp.Ordered, this=this, desc=desc, nulls_first=nulls_first, 
with_fill=with_fill - ) - - def _parse_limit_options(self) -> exp.LimitOptions: - percent = self._match(TokenType.PERCENT) - rows = self._match_set((TokenType.ROW, TokenType.ROWS)) - self._match_text_seq("ONLY") - with_ties = self._match_text_seq("WITH", "TIES") - return self.expression(exp.LimitOptions, percent=percent, rows=rows, with_ties=with_ties) - - def _parse_limit( - self, - this: t.Optional[exp.Expression] = None, - top: bool = False, - skip_limit_token: bool = False, - ) -> t.Optional[exp.Expression]: - if skip_limit_token or self._match(TokenType.TOP if top else TokenType.LIMIT): - comments = self._prev_comments - if top: - limit_paren = self._match(TokenType.L_PAREN) - expression = self._parse_term() if limit_paren else self._parse_number() - - if limit_paren: - self._match_r_paren() - - limit_options = self._parse_limit_options() - else: - limit_options = None - expression = self._parse_term() - - if self._match(TokenType.COMMA): - offset = expression - expression = self._parse_term() - else: - offset = None - - limit_exp = self.expression( - exp.Limit, - this=this, - expression=expression, - offset=offset, - comments=comments, - limit_options=limit_options, - expressions=self._parse_limit_by(), - ) - - return limit_exp - - if self._match(TokenType.FETCH): - direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) - direction = self._prev.text.upper() if direction else "FIRST" - - count = self._parse_field(tokens=self.FETCH_TOKENS) - - return self.expression( - exp.Fetch, - direction=direction, - count=count, - limit_options=self._parse_limit_options(), - ) - - return this - - def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: - if not self._match(TokenType.OFFSET): - return this - - count = self._parse_term() - self._match_set((TokenType.ROW, TokenType.ROWS)) - - return self.expression( - exp.Offset, this=this, expression=count, expressions=self._parse_limit_by() - ) - - def _can_parse_limit_or_offset(self) -> bool: - if not self._match_set(self.AMBIGUOUS_ALIAS_TOKENS, advance=False): - return False - - index = self._index - result = bool( - self._try_parse(self._parse_limit, retreat=True) - or self._try_parse(self._parse_offset, retreat=True) - ) - self._retreat(index) - return result - - def _parse_limit_by(self) -> t.Optional[t.List[exp.Expression]]: - return self._match_text_seq("BY") and self._parse_csv(self._parse_bitwise) - - def _parse_locks(self) -> t.List[exp.Lock]: - locks = [] - while True: - if self._match_text_seq("FOR", "UPDATE"): - update = True - elif self._match_text_seq("FOR", "SHARE") or self._match_text_seq( - "LOCK", "IN", "SHARE", "MODE" - ): - update = False - else: - break - - expressions = None - if self._match_text_seq("OF"): - expressions = self._parse_csv(lambda: self._parse_table(schema=True)) - - wait: t.Optional[bool | exp.Expression] = None - if self._match_text_seq("NOWAIT"): - wait = True - elif self._match_text_seq("WAIT"): - wait = self._parse_primary() - elif self._match_text_seq("SKIP", "LOCKED"): - wait = False - - locks.append( - self.expression(exp.Lock, update=update, expressions=expressions, wait=wait) - ) - - return locks - - def parse_set_operation(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - start = self._index - _, side_token, kind_token = self._parse_join_parts() - - side = side_token.text if side_token else None - kind = kind_token.text if kind_token else None - - if not self._match_set(self.SET_OPERATIONS): - self._retreat(start) - return None - - 
token_type = self._prev.token_type - - if token_type == TokenType.UNION: - operation: t.Type[exp.SetOperation] = exp.Union - elif token_type == TokenType.EXCEPT: - operation = exp.Except - else: - operation = exp.Intersect - - comments = self._prev.comments - - if self._match(TokenType.DISTINCT): - distinct: t.Optional[bool] = True - elif self._match(TokenType.ALL): - distinct = False - else: - distinct = self.dialect.SET_OP_DISTINCT_BY_DEFAULT[operation] - if distinct is None: - self.raise_error(f"Expected DISTINCT or ALL for {operation.__name__}") - - by_name = self._match_text_seq("BY", "NAME") or self._match_text_seq( - "STRICT", "CORRESPONDING" - ) - if self._match_text_seq("CORRESPONDING"): - by_name = True - if not side and not kind: - kind = "INNER" - - on_column_list = None - if by_name and self._match_texts(("ON", "BY")): - on_column_list = self._parse_wrapped_csv(self._parse_column) - - expression = self._parse_select(nested=True, parse_set_operation=False) - - return self.expression( - operation, - comments=comments, - this=this, - distinct=distinct, - by_name=by_name, - expression=expression, - side=side, - kind=kind, - on=on_column_list, - ) - - def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - while this: - setop = self.parse_set_operation(this) - if not setop: - break - this = setop - - if isinstance(this, exp.SetOperation) and self.MODIFIERS_ATTACHED_TO_SET_OP: - expression = this.expression - - if expression: - for arg in self.SET_OP_MODIFIERS: - expr = expression.args.get(arg) - if expr: - this.set(arg, expr.pop()) - - return this - - def _parse_expression(self) -> t.Optional[exp.Expression]: - return self._parse_alias(self._parse_assignment()) - - def _parse_assignment(self) -> t.Optional[exp.Expression]: - this = self._parse_disjunction() - if not this and self._next and self._next.token_type in self.ASSIGNMENT: - # This allows us to parse := - this = exp.column( - t.cast(str, self._advance_any(ignore_reserved=True) and self._prev.text) - ) - - while self._match_set(self.ASSIGNMENT): - if isinstance(this, exp.Column) and len(this.parts) == 1: - this = this.this - - this = self.expression( - self.ASSIGNMENT[self._prev.token_type], - this=this, - comments=self._prev_comments, - expression=self._parse_assignment(), - ) - - return this - - def _parse_disjunction(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_conjunction, self.DISJUNCTION) - - def _parse_conjunction(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_equality, self.CONJUNCTION) - - def _parse_equality(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_comparison, self.EQUALITY) - - def _parse_comparison(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_range, self.COMPARISON) - - def _parse_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: - this = this or self._parse_bitwise() - negate = self._match(TokenType.NOT) - - if self._match_set(self.RANGE_PARSERS): - expression = self.RANGE_PARSERS[self._prev.token_type](self, this) - if not expression: - return this - - this = expression - elif self._match(TokenType.ISNULL): - this = self.expression(exp.Is, this=this, expression=exp.Null()) - - # Postgres supports ISNULL and NOTNULL for conditions. 
- # https://blog.andreiavram.ro/postgresql-null-composite-type/ - if self._match(TokenType.NOTNULL): - this = self.expression(exp.Is, this=this, expression=exp.Null()) - this = self.expression(exp.Not, this=this) - - if negate: - this = self._negate_range(this) - - if self._match(TokenType.IS): - this = self._parse_is(this) - - return this - - def _negate_range(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: - if not this: - return this - - return self.expression(exp.Not, this=this) - - def _parse_is(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - index = self._index - 1 - negate = self._match(TokenType.NOT) - - if self._match_text_seq("DISTINCT", "FROM"): - klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ - return self.expression(klass, this=this, expression=self._parse_bitwise()) - - if self._match(TokenType.JSON): - kind = self._match_texts(self.IS_JSON_PREDICATE_KIND) and self._prev.text.upper() - - if self._match_text_seq("WITH"): - _with = True - elif self._match_text_seq("WITHOUT"): - _with = False - else: - _with = None - - unique = self._match(TokenType.UNIQUE) - self._match_text_seq("KEYS") - expression: t.Optional[exp.Expression] = self.expression( - exp.JSON, **{"this": kind, "with": _with, "unique": unique} - ) - else: - expression = self._parse_primary() or self._parse_null() - if not expression: - self._retreat(index) - return None - - this = self.expression(exp.Is, this=this, expression=expression) - return self.expression(exp.Not, this=this) if negate else this - - def _parse_in(self, this: t.Optional[exp.Expression], alias: bool = False) -> exp.In: - unnest = self._parse_unnest(with_alias=False) - if unnest: - this = self.expression(exp.In, this=this, unnest=unnest) - elif self._match_set((TokenType.L_PAREN, TokenType.L_BRACKET)): - matched_l_paren = self._prev.token_type == TokenType.L_PAREN - expressions = self._parse_csv(lambda: self._parse_select_or_expression(alias=alias)) - - if len(expressions) == 1 and isinstance(expressions[0], exp.Query): - this = self.expression(exp.In, this=this, query=expressions[0].subquery(copy=False)) - else: - this = self.expression(exp.In, this=this, expressions=expressions) - - if matched_l_paren: - self._match_r_paren(this) - elif not self._match(TokenType.R_BRACKET, expression=this): - self.raise_error("Expecting ]") - else: - this = self.expression(exp.In, this=this, field=self._parse_column()) - - return this - - def _parse_between(self, this: t.Optional[exp.Expression]) -> exp.Between: - low = self._parse_bitwise() - self._match(TokenType.AND) - high = self._parse_bitwise() - return self.expression(exp.Between, this=this, low=low, high=high) - - def _parse_escape(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if not self._match(TokenType.ESCAPE): - return this - return self.expression(exp.Escape, this=this, expression=self._parse_string()) - - def _parse_interval(self, match_interval: bool = True) -> t.Optional[exp.Add | exp.Interval]: - index = self._index - - if not self._match(TokenType.INTERVAL) and match_interval: - return None - - if self._match(TokenType.STRING, advance=False): - this = self._parse_primary() - else: - this = self._parse_term() - - if not this or ( - isinstance(this, exp.Column) - and not this.table - and not this.this.quoted - and this.name.upper() == "IS" - ): - self._retreat(index) - return None - - unit = self._parse_function() or ( - not self._match(TokenType.ALIAS, advance=False) - and self._parse_var(any_token=True, 
upper=True) - ) - - # Most dialects support, e.g., the form INTERVAL '5' day, thus we try to parse - # each INTERVAL expression into this canonical form so it's easy to transpile - if this and this.is_number: - this = exp.Literal.string(this.to_py()) - elif this and this.is_string: - parts = exp.INTERVAL_STRING_RE.findall(this.name) - if parts and unit: - # Unconsume the eagerly-parsed unit, since the real unit was part of the string - unit = None - self._retreat(self._index - 1) - - if len(parts) == 1: - this = exp.Literal.string(parts[0][0]) - unit = self.expression(exp.Var, this=parts[0][1].upper()) - if self.INTERVAL_SPANS and self._match_text_seq("TO"): - unit = self.expression( - exp.IntervalSpan, this=unit, expression=self._parse_var(any_token=True, upper=True) - ) - - interval = self.expression(exp.Interval, this=this, unit=unit) - - index = self._index - self._match(TokenType.PLUS) - - # Convert INTERVAL 'val_1' unit_1 [+] ... [+] 'val_n' unit_n into a sum of intervals - if self._match_set((TokenType.STRING, TokenType.NUMBER), advance=False): - return self.expression( - exp.Add, this=interval, expression=self._parse_interval(match_interval=False) - ) - - self._retreat(index) - return interval - - def _parse_bitwise(self) -> t.Optional[exp.Expression]: - this = self._parse_term() - - while True: - if self._match_set(self.BITWISE): - this = self.expression( - self.BITWISE[self._prev.token_type], - this=this, - expression=self._parse_term(), - ) - elif self.dialect.DPIPE_IS_STRING_CONCAT and self._match(TokenType.DPIPE): - this = self.expression( - exp.DPipe, - this=this, - expression=self._parse_term(), - safe=not self.dialect.STRICT_STRING_CONCAT, - ) - elif self._match(TokenType.DQMARK): - this = self.expression( - exp.Coalesce, this=this, expressions=ensure_list(self._parse_term()) - ) - elif self._match_pair(TokenType.LT, TokenType.LT): - this = self.expression( - exp.BitwiseLeftShift, this=this, expression=self._parse_term() - ) - elif self._match_pair(TokenType.GT, TokenType.GT): - this = self.expression( - exp.BitwiseRightShift, this=this, expression=self._parse_term() - ) - else: - break - - return this - - def _parse_term(self) -> t.Optional[exp.Expression]: - this = self._parse_factor() - - while self._match_set(self.TERM): - klass = self.TERM[self._prev.token_type] - comments = self._prev_comments - expression = self._parse_factor() - - this = self.expression(klass, this=this, comments=comments, expression=expression) - - if isinstance(this, exp.Collate): - expr = this.expression - - # Preserve collations such as pg_catalog."default" (Postgres) as columns, otherwise - # fallback to Identifier / Var - if isinstance(expr, exp.Column) and len(expr.parts) == 1: - ident = expr.this - if isinstance(ident, exp.Identifier): - this.set("expression", ident if ident.quoted else exp.var(ident.name)) - - return this - - def _parse_factor(self) -> t.Optional[exp.Expression]: - parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary - this = parse_method() - - while self._match_set(self.FACTOR): - klass = self.FACTOR[self._prev.token_type] - comments = self._prev_comments - expression = parse_method() - - if not expression and klass is exp.IntDiv and self._prev.text.isalpha(): - self._retreat(self._index - 1) - return this - - this = self.expression(klass, this=this, comments=comments, expression=expression) - - if isinstance(this, exp.Div): - this.args["typed"] = self.dialect.TYPED_DIVISION - this.args["safe"] = self.dialect.SAFE_DIVISION - - return this - - def 
_parse_exponent(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_unary, self.EXPONENT) - - def _parse_unary(self) -> t.Optional[exp.Expression]: - if self._match_set(self.UNARY_PARSERS): - return self.UNARY_PARSERS[self._prev.token_type](self) - return self._parse_at_time_zone(self._parse_type()) - - def _parse_type( - self, parse_interval: bool = True, fallback_to_identifier: bool = False - ) -> t.Optional[exp.Expression]: - interval = parse_interval and self._parse_interval() - if interval: - return interval - - index = self._index - data_type = self._parse_types(check_func=True, allow_identifiers=False) - - # parse_types() returns a Cast if we parsed BQ's inline constructor () e.g. - # STRUCT(1, 'foo'), which is canonicalized to CAST( AS ) - if isinstance(data_type, exp.Cast): - # This constructor can contain ops directly after it, for instance struct unnesting: - # STRUCT(1, 'foo').* --> CAST(STRUCT(1, 'foo') AS STRUCT 1: - self._retreat(index2) - return self._parse_column_ops(data_type) - - self._retreat(index) - - if fallback_to_identifier: - return self._parse_id_var() - - this = self._parse_column() - return this and self._parse_column_ops(this) - - def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]: - this = self._parse_type() - if not this: - return None - - if isinstance(this, exp.Column) and not this.table: - this = exp.var(this.name.upper()) - - return self.expression( - exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True) - ) - - def _parse_user_defined_type(self, identifier: exp.Identifier) -> t.Optional[exp.Expression]: - type_name = identifier.name - - while self._match(TokenType.DOT): - type_name = f"{type_name}.{self._advance_any() and self._prev.text}" - - return exp.DataType.build(type_name, udt=True) - - def _parse_types( - self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True - ) -> t.Optional[exp.Expression]: - index = self._index - - this: t.Optional[exp.Expression] = None - prefix = self._match_text_seq("SYSUDTLIB", ".") - - if not self._match_set(self.TYPE_TOKENS): - identifier = allow_identifiers and self._parse_id_var( - any_token=False, tokens=(TokenType.VAR,) - ) - if isinstance(identifier, exp.Identifier): - tokens = self.dialect.tokenize(identifier.sql(dialect=self.dialect)) - - if len(tokens) != 1: - self.raise_error("Unexpected identifier", self._prev) - - if tokens[0].token_type in self.TYPE_TOKENS: - self._prev = tokens[0] - elif self.dialect.SUPPORTS_USER_DEFINED_TYPES: - this = self._parse_user_defined_type(identifier) - else: - self._retreat(self._index - 1) - return None - else: - return None - - type_token = self._prev.token_type - - if type_token == TokenType.PSEUDO_TYPE: - return self.expression(exp.PseudoType, this=self._prev.text.upper()) - - if type_token == TokenType.OBJECT_IDENTIFIER: - return self.expression(exp.ObjectIdentifier, this=self._prev.text.upper()) - - # https://materialize.com/docs/sql/types/map/ - if type_token == TokenType.MAP and self._match(TokenType.L_BRACKET): - key_type = self._parse_types( - check_func=check_func, schema=schema, allow_identifiers=allow_identifiers - ) - if not self._match(TokenType.FARROW): - self._retreat(index) - return None - - value_type = self._parse_types( - check_func=check_func, schema=schema, allow_identifiers=allow_identifiers - ) - if not self._match(TokenType.R_BRACKET): - self._retreat(index) - return None - - return exp.DataType( - this=exp.DataType.Type.MAP, - expressions=[key_type, value_type], - 
nested=True, - prefix=prefix, - ) - - nested = type_token in self.NESTED_TYPE_TOKENS - is_struct = type_token in self.STRUCT_TYPE_TOKENS - is_aggregate = type_token in self.AGGREGATE_TYPE_TOKENS - expressions = None - maybe_func = False - - if self._match(TokenType.L_PAREN): - if is_struct: - expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True)) - elif nested: - expressions = self._parse_csv( - lambda: self._parse_types( - check_func=check_func, schema=schema, allow_identifiers=allow_identifiers - ) - ) - if type_token == TokenType.NULLABLE and len(expressions) == 1: - this = expressions[0] - this.set("nullable", True) - self._match_r_paren() - return this - elif type_token in self.ENUM_TYPE_TOKENS: - expressions = self._parse_csv(self._parse_equality) - elif is_aggregate: - func_or_ident = self._parse_function(anonymous=True) or self._parse_id_var( - any_token=False, tokens=(TokenType.VAR, TokenType.ANY) - ) - if not func_or_ident: - return None - expressions = [func_or_ident] - if self._match(TokenType.COMMA): - expressions.extend( - self._parse_csv( - lambda: self._parse_types( - check_func=check_func, - schema=schema, - allow_identifiers=allow_identifiers, - ) - ) - ) - else: - expressions = self._parse_csv(self._parse_type_size) - - # https://docs.snowflake.com/en/sql-reference/data-types-vector - if type_token == TokenType.VECTOR and len(expressions) == 2: - expressions[0] = exp.DataType.build(expressions[0].name, dialect=self.dialect) - - if not expressions or not self._match(TokenType.R_PAREN): - self._retreat(index) - return None - - maybe_func = True - - values: t.Optional[t.List[exp.Expression]] = None - - if nested and self._match(TokenType.LT): - if is_struct: - expressions = self._parse_csv(lambda: self._parse_struct_types(type_required=True)) - else: - expressions = self._parse_csv( - lambda: self._parse_types( - check_func=check_func, schema=schema, allow_identifiers=allow_identifiers - ) - ) - - if not self._match(TokenType.GT): - self.raise_error("Expecting >") - - if self._match_set((TokenType.L_BRACKET, TokenType.L_PAREN)): - values = self._parse_csv(self._parse_assignment) - if not values and is_struct: - values = None - self._retreat(self._index - 1) - else: - self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN)) - - if type_token in self.TIMESTAMPS: - if self._match_text_seq("WITH", "TIME", "ZONE"): - maybe_func = False - tz_type = ( - exp.DataType.Type.TIMETZ - if type_token in self.TIMES - else exp.DataType.Type.TIMESTAMPTZ - ) - this = exp.DataType(this=tz_type, expressions=expressions) - elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"): - maybe_func = False - this = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) - elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): - maybe_func = False - elif type_token == TokenType.INTERVAL: - unit = self._parse_var(upper=True) - if unit: - if self._match_text_seq("TO"): - unit = exp.IntervalSpan(this=unit, expression=self._parse_var(upper=True)) - - this = self.expression(exp.DataType, this=self.expression(exp.Interval, unit=unit)) - else: - this = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL) - elif type_token == TokenType.VOID: - this = exp.DataType(this=exp.DataType.Type.NULL) - - if maybe_func and check_func: - index2 = self._index - peek = self._parse_string() - - if not peek: - self._retreat(index) - return None - - self._retreat(index2) - - if not this: - if self._match_text_seq("UNSIGNED"): - unsigned_type_token = 
self.SIGNED_TO_UNSIGNED_TYPE_TOKEN.get(type_token) - if not unsigned_type_token: - self.raise_error(f"Cannot convert {type_token.value} to unsigned.") - - type_token = unsigned_type_token or type_token - - this = exp.DataType( - this=exp.DataType.Type[type_token.value], - expressions=expressions, - nested=nested, - prefix=prefix, - ) - - # Empty arrays/structs are allowed - if values is not None: - cls = exp.Struct if is_struct else exp.Array - this = exp.cast(cls(expressions=values), this, copy=False) - - elif expressions: - this.set("expressions", expressions) - - # https://materialize.com/docs/sql/types/list/#type-name - while self._match(TokenType.LIST): - this = exp.DataType(this=exp.DataType.Type.LIST, expressions=[this], nested=True) - - index = self._index - - # Postgres supports the INT ARRAY[3] syntax as a synonym for INT[3] - matched_array = self._match(TokenType.ARRAY) - - while self._curr: - datatype_token = self._prev.token_type - matched_l_bracket = self._match(TokenType.L_BRACKET) - - if (not matched_l_bracket and not matched_array) or ( - datatype_token == TokenType.ARRAY and self._match(TokenType.R_BRACKET) - ): - # Postgres allows casting empty arrays such as ARRAY[]::INT[], - # not to be confused with the fixed size array parsing - break - - matched_array = False - values = self._parse_csv(self._parse_assignment) or None - if ( - values - and not schema - and ( - not self.dialect.SUPPORTS_FIXED_SIZE_ARRAYS or datatype_token == TokenType.ARRAY - ) - ): - # Retreating here means that we should not parse the following values as part of the data type, e.g. in DuckDB - # ARRAY[1] should retreat and instead be parsed into exp.Array in contrast to INT[x][y] which denotes a fixed-size array data type - self._retreat(index) - break - - this = exp.DataType( - this=exp.DataType.Type.ARRAY, expressions=[this], values=values, nested=True - ) - self._match(TokenType.R_BRACKET) - - if self.TYPE_CONVERTERS and isinstance(this.this, exp.DataType.Type): - converter = self.TYPE_CONVERTERS.get(this.this) - if converter: - this = converter(t.cast(exp.DataType, this)) - - return this - - def _parse_struct_types(self, type_required: bool = False) -> t.Optional[exp.Expression]: - index = self._index - - if ( - self._curr - and self._next - and self._curr.token_type in self.TYPE_TOKENS - and self._next.token_type in self.TYPE_TOKENS - ): - # Takes care of special cases like `STRUCT>` where the identifier is also a - # type token. 
Without this, the list will be parsed as a type and we'll eventually crash - this = self._parse_id_var() - else: - this = ( - self._parse_type(parse_interval=False, fallback_to_identifier=True) - or self._parse_id_var() - ) - - self._match(TokenType.COLON) - - if ( - type_required - and not isinstance(this, exp.DataType) - and not self._match_set(self.TYPE_TOKENS, advance=False) - ): - self._retreat(index) - return self._parse_types() - - return self._parse_column_def(this) - - def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if not self._match_text_seq("AT", "TIME", "ZONE"): - return this - return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) - - def _parse_column(self) -> t.Optional[exp.Expression]: - this = self._parse_column_reference() - column = self._parse_column_ops(this) if this else self._parse_bracket(this) - - if self.dialect.SUPPORTS_COLUMN_JOIN_MARKS and column: - column.set("join_mark", self._match(TokenType.JOIN_MARKER)) - - return column - - def _parse_column_reference(self) -> t.Optional[exp.Expression]: - this = self._parse_field() - if ( - not this - and self._match(TokenType.VALUES, advance=False) - and self.VALUES_FOLLOWED_BY_PAREN - and (not self._next or self._next.token_type != TokenType.L_PAREN) - ): - this = self._parse_id_var() - - if isinstance(this, exp.Identifier): - # We bubble up comments from the Identifier to the Column - this = self.expression(exp.Column, comments=this.pop_comments(), this=this) - - return this - - def _parse_colon_as_variant_extract( - self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - casts = [] - json_path = [] - escape = None - - while self._match(TokenType.COLON): - start_index = self._index - - # Snowflake allows reserved keywords as json keys but advance_any() excludes TokenType.SELECT from any_tokens=True - path = self._parse_column_ops( - self._parse_field(any_token=True, tokens=(TokenType.SELECT,)) - ) - - # The cast :: operator has a lower precedence than the extraction operator :, so - # we rearrange the AST appropriately to avoid casting the JSON path - while isinstance(path, exp.Cast): - casts.append(path.to) - path = path.this - - if casts: - dcolon_offset = next( - i - for i, t in enumerate(self._tokens[start_index:]) - if t.token_type == TokenType.DCOLON - ) - end_token = self._tokens[start_index + dcolon_offset - 1] - else: - end_token = self._prev - - if path: - # Escape single quotes from Snowflake's colon extraction (e.g. 
col:"a'b") as - # it'll roundtrip to a string literal in GET_PATH - if isinstance(path, exp.Identifier) and path.quoted: - escape = True - - json_path.append(self._find_sql(self._tokens[start_index], end_token)) - - # The VARIANT extract in Snowflake/Databricks is parsed as a JSONExtract; Snowflake uses the json_path in GET_PATH() while - # Databricks transforms it back to the colon/dot notation - if json_path: - json_path_expr = self.dialect.to_json_path(exp.Literal.string(".".join(json_path))) - - if json_path_expr: - json_path_expr.set("escape", escape) - - this = self.expression( - exp.JSONExtract, - this=this, - expression=json_path_expr, - variant_extract=True, - ) - - while casts: - this = self.expression(exp.Cast, this=this, to=casts.pop()) - - return this - - def _parse_dcolon(self) -> t.Optional[exp.Expression]: - return self._parse_types() - - def _parse_column_ops(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - this = self._parse_bracket(this) - - while self._match_set(self.COLUMN_OPERATORS): - op_token = self._prev.token_type - op = self.COLUMN_OPERATORS.get(op_token) - - if op_token in (TokenType.DCOLON, TokenType.DOTCOLON): - field = self._parse_dcolon() - if not field: - self.raise_error("Expected type") - elif op and self._curr: - field = self._parse_column_reference() or self._parse_bracket() - if isinstance(field, exp.Column) and self._match(TokenType.DOT, advance=False): - field = self._parse_column_ops(field) - else: - field = self._parse_field(any_token=True, anonymous_func=True) - - # Function calls can be qualified, e.g., x.y.FOO() - # This converts the final AST to a series of Dots leading to the function call - # https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference#function_call_rules - if isinstance(field, (exp.Func, exp.Window)) and this: - this = this.transform( - lambda n: n.to_dot(include_dots=False) if isinstance(n, exp.Column) else n - ) - - if op: - this = op(self, this, field) - elif isinstance(this, exp.Column) and not this.args.get("catalog"): - this = self.expression( - exp.Column, - comments=this.comments, - this=field, - table=this.this, - db=this.args.get("table"), - catalog=this.args.get("db"), - ) - elif isinstance(field, exp.Window): - # Move the exp.Dot's to the window's function - window_func = self.expression(exp.Dot, this=this, expression=field.this) - field.set("this", window_func) - this = field - else: - this = self.expression(exp.Dot, this=this, expression=field) - - if field and field.comments: - t.cast(exp.Expression, this).add_comments(field.pop_comments()) - - this = self._parse_bracket(this) - - return self._parse_colon_as_variant_extract(this) if self.COLON_IS_VARIANT_EXTRACT else this - - def _parse_paren(self) -> t.Optional[exp.Expression]: - if not self._match(TokenType.L_PAREN): - return None - - comments = self._prev_comments - query = self._parse_select() - - if query: - expressions = [query] - else: - expressions = self._parse_expressions() - - this = self._parse_query_modifiers(seq_get(expressions, 0)) - - if not this and self._match(TokenType.R_PAREN, advance=False): - this = self.expression(exp.Tuple) - elif isinstance(this, exp.UNWRAPPED_QUERIES): - this = self._parse_subquery(this=this, parse_alias=False) - elif isinstance(this, exp.Subquery): - this = self._parse_subquery(this=self._parse_set_operations(this), parse_alias=False) - elif len(expressions) > 1 or self._prev.token_type == TokenType.COMMA: - this = self.expression(exp.Tuple, expressions=expressions) - 
else: - this = self.expression(exp.Paren, this=this) - - if this: - this.add_comments(comments) - - self._match_r_paren(expression=this) - return this - - def _parse_primary(self) -> t.Optional[exp.Expression]: - if self._match_set(self.PRIMARY_PARSERS): - token_type = self._prev.token_type - primary = self.PRIMARY_PARSERS[token_type](self, self._prev) - - if token_type == TokenType.STRING: - expressions = [primary] - while self._match(TokenType.STRING): - expressions.append(exp.Literal.string(self._prev.text)) - - if len(expressions) > 1: - return self.expression(exp.Concat, expressions=expressions) - - return primary - - if self._match_pair(TokenType.DOT, TokenType.NUMBER): - return exp.Literal.number(f"0.{self._prev.text}") - - return self._parse_paren() - - def _parse_field( - self, - any_token: bool = False, - tokens: t.Optional[t.Collection[TokenType]] = None, - anonymous_func: bool = False, - ) -> t.Optional[exp.Expression]: - if anonymous_func: - field = ( - self._parse_function(anonymous=anonymous_func, any_token=any_token) - or self._parse_primary() - ) - else: - field = self._parse_primary() or self._parse_function( - anonymous=anonymous_func, any_token=any_token - ) - return field or self._parse_id_var(any_token=any_token, tokens=tokens) - - def _parse_function( - self, - functions: t.Optional[t.Dict[str, t.Callable]] = None, - anonymous: bool = False, - optional_parens: bool = True, - any_token: bool = False, - ) -> t.Optional[exp.Expression]: - # This allows us to also parse {fn } syntax (Snowflake, MySQL support this) - # See: https://community.snowflake.com/s/article/SQL-Escape-Sequences - fn_syntax = False - if ( - self._match(TokenType.L_BRACE, advance=False) - and self._next - and self._next.text.upper() == "FN" - ): - self._advance(2) - fn_syntax = True - - func = self._parse_function_call( - functions=functions, - anonymous=anonymous, - optional_parens=optional_parens, - any_token=any_token, - ) - - if fn_syntax: - self._match(TokenType.R_BRACE) - - return func - - def _parse_function_call( - self, - functions: t.Optional[t.Dict[str, t.Callable]] = None, - anonymous: bool = False, - optional_parens: bool = True, - any_token: bool = False, - ) -> t.Optional[exp.Expression]: - if not self._curr: - return None - - comments = self._curr.comments - token = self._curr - token_type = self._curr.token_type - this = self._curr.text - upper = this.upper() - - parser = self.NO_PAREN_FUNCTION_PARSERS.get(upper) - if optional_parens and parser and token_type not in self.INVALID_FUNC_NAME_TOKENS: - self._advance() - return self._parse_window(parser(self)) - - if not self._next or self._next.token_type != TokenType.L_PAREN: - if optional_parens and token_type in self.NO_PAREN_FUNCTIONS: - self._advance() - return self.expression(self.NO_PAREN_FUNCTIONS[token_type]) - - return None - - if any_token: - if token_type in self.RESERVED_TOKENS: - return None - elif token_type not in self.FUNC_TOKENS: - return None - - self._advance(2) - - parser = self.FUNCTION_PARSERS.get(upper) - if parser and not anonymous: - this = parser(self) - else: - subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) - - if subquery_predicate and self._curr.token_type in (TokenType.SELECT, TokenType.WITH): - this = self.expression( - subquery_predicate, comments=comments, this=self._parse_select() - ) - self._match_r_paren() - return this - - if functions is None: - functions = self.FUNCTIONS - - function = functions.get(upper) - known_function = function and not anonymous - - alias = not known_function 
or upper in self.FUNCTIONS_WITH_ALIASED_ARGS - args = self._parse_csv(lambda: self._parse_lambda(alias=alias)) - - post_func_comments = self._curr and self._curr.comments - if known_function and post_func_comments: - # If the user-inputted comment "/* sqlglot.anonymous */" is following the function - # call we'll construct it as exp.Anonymous, even if it's "known" - if any( - comment.lstrip().startswith(exp.SQLGLOT_ANONYMOUS) - for comment in post_func_comments - ): - known_function = False - - if alias and known_function: - args = self._kv_to_prop_eq(args) - - if known_function: - func_builder = t.cast(t.Callable, function) - - if "dialect" in func_builder.__code__.co_varnames: - func = func_builder(args, dialect=self.dialect) - else: - func = func_builder(args) - - func = self.validate_expression(func, args) - if self.dialect.PRESERVE_ORIGINAL_NAMES: - func.meta["name"] = this - - this = func - else: - if token_type == TokenType.IDENTIFIER: - this = exp.Identifier(this=this, quoted=True).update_positions(token) - - this = self.expression(exp.Anonymous, this=this, expressions=args) - this = this.update_positions(token) - - if isinstance(this, exp.Expression): - this.add_comments(comments) - - self._match_r_paren(this) - return self._parse_window(this) - - def _to_prop_eq(self, expression: exp.Expression, index: int) -> exp.Expression: - return expression - - def _kv_to_prop_eq(self, expressions: t.List[exp.Expression]) -> t.List[exp.Expression]: - transformed = [] - - for index, e in enumerate(expressions): - if isinstance(e, self.KEY_VALUE_DEFINITIONS): - if isinstance(e, exp.Alias): - e = self.expression(exp.PropertyEQ, this=e.args.get("alias"), expression=e.this) - - if not isinstance(e, exp.PropertyEQ): - e = self.expression( - exp.PropertyEQ, this=exp.to_identifier(e.this.name), expression=e.expression - ) - - if isinstance(e.this, exp.Column): - e.this.replace(e.this.this) - else: - e = self._to_prop_eq(e, index) - - transformed.append(e) - - return transformed - - def _parse_user_defined_function_expression(self) -> t.Optional[exp.Expression]: - return self._parse_statement() - - def _parse_function_parameter(self) -> t.Optional[exp.Expression]: - return self._parse_column_def(this=self._parse_id_var(), computed_column=False) - - def _parse_user_defined_function( - self, kind: t.Optional[TokenType] = None - ) -> t.Optional[exp.Expression]: - this = self._parse_table_parts(schema=True) - - if not self._match(TokenType.L_PAREN): - return this - - expressions = self._parse_csv(self._parse_function_parameter) - self._match_r_paren() - return self.expression( - exp.UserDefinedFunction, this=this, expressions=expressions, wrapped=True - ) - - def _parse_introducer(self, token: Token) -> exp.Introducer | exp.Identifier: - literal = self._parse_primary() - if literal: - return self.expression(exp.Introducer, this=token.text, expression=literal) - - return self._identifier_expression(token) - - def _parse_session_parameter(self) -> exp.SessionParameter: - kind = None - this = self._parse_id_var() or self._parse_primary() - - if this and self._match(TokenType.DOT): - kind = this.name - this = self._parse_var() or self._parse_primary() - - return self.expression(exp.SessionParameter, this=this, kind=kind) - - def _parse_lambda_arg(self) -> t.Optional[exp.Expression]: - return self._parse_id_var() - - def _parse_lambda(self, alias: bool = False) -> t.Optional[exp.Expression]: - index = self._index - - if self._match(TokenType.L_PAREN): - expressions = t.cast( - 
t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_lambda_arg) - ) - - if not self._match(TokenType.R_PAREN): - self._retreat(index) - else: - expressions = [self._parse_lambda_arg()] - - if self._match_set(self.LAMBDAS): - return self.LAMBDAS[self._prev.token_type](self, expressions) - - self._retreat(index) - - this: t.Optional[exp.Expression] - - if self._match(TokenType.DISTINCT): - this = self.expression( - exp.Distinct, expressions=self._parse_csv(self._parse_assignment) - ) - else: - this = self._parse_select_or_expression(alias=alias) - - return self._parse_limit( - self._parse_order(self._parse_having_max(self._parse_respect_or_ignore_nulls(this))) - ) - - def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: - index = self._index - if not self._match(TokenType.L_PAREN): - return this - - # Disambiguate between schema and subquery/CTE, e.g. in INSERT INTO table (), - # expr can be of both types - if self._match_set(self.SELECT_START_TOKENS): - self._retreat(index) - return this - args = self._parse_csv(lambda: self._parse_constraint() or self._parse_field_def()) - self._match_r_paren() - return self.expression(exp.Schema, this=this, expressions=args) - - def _parse_field_def(self) -> t.Optional[exp.Expression]: - return self._parse_column_def(self._parse_field(any_token=True)) - - def _parse_column_def( - self, this: t.Optional[exp.Expression], computed_column: bool = True - ) -> t.Optional[exp.Expression]: - # column defs are not really columns, they're identifiers - if isinstance(this, exp.Column): - this = this.this - - if not computed_column: - self._match(TokenType.ALIAS) - - kind = self._parse_types(schema=True) - - if self._match_text_seq("FOR", "ORDINALITY"): - return self.expression(exp.ColumnDef, this=this, ordinality=True) - - constraints: t.List[exp.Expression] = [] - - if (not kind and self._match(TokenType.ALIAS)) or self._match_texts( - ("ALIAS", "MATERIALIZED") - ): - persisted = self._prev.text.upper() == "MATERIALIZED" - constraint_kind = exp.ComputedColumnConstraint( - this=self._parse_assignment(), - persisted=persisted or self._match_text_seq("PERSISTED"), - not_null=self._match_pair(TokenType.NOT, TokenType.NULL), - ) - constraints.append(self.expression(exp.ColumnConstraint, kind=constraint_kind)) - elif ( - kind - and self._match(TokenType.ALIAS, advance=False) - and ( - not self.WRAPPED_TRANSFORM_COLUMN_CONSTRAINT - or (self._next and self._next.token_type == TokenType.L_PAREN) - ) - ): - self._advance() - constraints.append( - self.expression( - exp.ColumnConstraint, - kind=exp.ComputedColumnConstraint( - this=self._parse_disjunction(), - persisted=self._match_texts(("STORED", "VIRTUAL")) - and self._prev.text.upper() == "STORED", - ), - ) - ) - - while True: - constraint = self._parse_column_constraint() - if not constraint: - break - constraints.append(constraint) - - if not kind and not constraints: - return this - - return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) - - def _parse_auto_increment( - self, - ) -> exp.GeneratedAsIdentityColumnConstraint | exp.AutoIncrementColumnConstraint: - start = None - increment = None - - if self._match(TokenType.L_PAREN, advance=False): - args = self._parse_wrapped_csv(self._parse_bitwise) - start = seq_get(args, 0) - increment = seq_get(args, 1) - elif self._match_text_seq("START"): - start = self._parse_bitwise() - self._match_text_seq("INCREMENT") - increment = self._parse_bitwise() - - if start and increment: - return 
exp.GeneratedAsIdentityColumnConstraint( - start=start, increment=increment, this=False - ) - - return exp.AutoIncrementColumnConstraint() - - def _parse_auto_property(self) -> t.Optional[exp.AutoRefreshProperty]: - if not self._match_text_seq("REFRESH"): - self._retreat(self._index - 1) - return None - return self.expression(exp.AutoRefreshProperty, this=self._parse_var(upper=True)) - - def _parse_compress(self) -> exp.CompressColumnConstraint: - if self._match(TokenType.L_PAREN, advance=False): - return self.expression( - exp.CompressColumnConstraint, this=self._parse_wrapped_csv(self._parse_bitwise) - ) - - return self.expression(exp.CompressColumnConstraint, this=self._parse_bitwise()) - - def _parse_generated_as_identity( - self, - ) -> ( - exp.GeneratedAsIdentityColumnConstraint - | exp.ComputedColumnConstraint - | exp.GeneratedAsRowColumnConstraint - ): - if self._match_text_seq("BY", "DEFAULT"): - on_null = self._match_pair(TokenType.ON, TokenType.NULL) - this = self.expression( - exp.GeneratedAsIdentityColumnConstraint, this=False, on_null=on_null - ) - else: - self._match_text_seq("ALWAYS") - this = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True) - - self._match(TokenType.ALIAS) - - if self._match_text_seq("ROW"): - start = self._match_text_seq("START") - if not start: - self._match(TokenType.END) - hidden = self._match_text_seq("HIDDEN") - return self.expression(exp.GeneratedAsRowColumnConstraint, start=start, hidden=hidden) - - identity = self._match_text_seq("IDENTITY") - - if self._match(TokenType.L_PAREN): - if self._match(TokenType.START_WITH): - this.set("start", self._parse_bitwise()) - if self._match_text_seq("INCREMENT", "BY"): - this.set("increment", self._parse_bitwise()) - if self._match_text_seq("MINVALUE"): - this.set("minvalue", self._parse_bitwise()) - if self._match_text_seq("MAXVALUE"): - this.set("maxvalue", self._parse_bitwise()) - - if self._match_text_seq("CYCLE"): - this.set("cycle", True) - elif self._match_text_seq("NO", "CYCLE"): - this.set("cycle", False) - - if not identity: - this.set("expression", self._parse_range()) - elif not this.args.get("start") and self._match(TokenType.NUMBER, advance=False): - args = self._parse_csv(self._parse_bitwise) - this.set("start", seq_get(args, 0)) - this.set("increment", seq_get(args, 1)) - - self._match_r_paren() - - return this - - def _parse_inline(self) -> exp.InlineLengthColumnConstraint: - self._match_text_seq("LENGTH") - return self.expression(exp.InlineLengthColumnConstraint, this=self._parse_bitwise()) - - def _parse_not_constraint(self) -> t.Optional[exp.Expression]: - if self._match_text_seq("NULL"): - return self.expression(exp.NotNullColumnConstraint) - if self._match_text_seq("CASESPECIFIC"): - return self.expression(exp.CaseSpecificColumnConstraint, not_=True) - if self._match_text_seq("FOR", "REPLICATION"): - return self.expression(exp.NotForReplicationColumnConstraint) - - # Unconsume the `NOT` token - self._retreat(self._index - 1) - return None - - def _parse_column_constraint(self) -> t.Optional[exp.Expression]: - this = self._match(TokenType.CONSTRAINT) and self._parse_id_var() - - procedure_option_follows = ( - self._match(TokenType.WITH, advance=False) - and self._next - and self._next.text.upper() in self.PROCEDURE_OPTIONS - ) - - if not procedure_option_follows and self._match_texts(self.CONSTRAINT_PARSERS): - return self.expression( - exp.ColumnConstraint, - this=this, - kind=self.CONSTRAINT_PARSERS[self._prev.text.upper()](self), - ) - - return this - - def 
_parse_constraint(self) -> t.Optional[exp.Expression]: - if not self._match(TokenType.CONSTRAINT): - return self._parse_unnamed_constraint(constraints=self.SCHEMA_UNNAMED_CONSTRAINTS) - - return self.expression( - exp.Constraint, - this=self._parse_id_var(), - expressions=self._parse_unnamed_constraints(), - ) - - def _parse_unnamed_constraints(self) -> t.List[exp.Expression]: - constraints = [] - while True: - constraint = self._parse_unnamed_constraint() or self._parse_function() - if not constraint: - break - constraints.append(constraint) - - return constraints - - def _parse_unnamed_constraint( - self, constraints: t.Optional[t.Collection[str]] = None - ) -> t.Optional[exp.Expression]: - if self._match(TokenType.IDENTIFIER, advance=False) or not self._match_texts( - constraints or self.CONSTRAINT_PARSERS - ): - return None - - constraint = self._prev.text.upper() - if constraint not in self.CONSTRAINT_PARSERS: - self.raise_error(f"No parser found for schema constraint {constraint}.") - - return self.CONSTRAINT_PARSERS[constraint](self) - - def _parse_unique_key(self) -> t.Optional[exp.Expression]: - return self._parse_id_var(any_token=False) - - def _parse_unique(self) -> exp.UniqueColumnConstraint: - self._match_text_seq("KEY") - return self.expression( - exp.UniqueColumnConstraint, - nulls=self._match_text_seq("NULLS", "NOT", "DISTINCT"), - this=self._parse_schema(self._parse_unique_key()), - index_type=self._match(TokenType.USING) and self._advance_any() and self._prev.text, - on_conflict=self._parse_on_conflict(), - options=self._parse_key_constraint_options(), - ) - - def _parse_key_constraint_options(self) -> t.List[str]: - options = [] - while True: - if not self._curr: - break - - if self._match(TokenType.ON): - action = None - on = self._advance_any() and self._prev.text - - if self._match_text_seq("NO", "ACTION"): - action = "NO ACTION" - elif self._match_text_seq("CASCADE"): - action = "CASCADE" - elif self._match_text_seq("RESTRICT"): - action = "RESTRICT" - elif self._match_pair(TokenType.SET, TokenType.NULL): - action = "SET NULL" - elif self._match_pair(TokenType.SET, TokenType.DEFAULT): - action = "SET DEFAULT" - else: - self.raise_error("Invalid key constraint") - - options.append(f"ON {on} {action}") - else: - var = self._parse_var_from_options( - self.KEY_CONSTRAINT_OPTIONS, raise_unmatched=False - ) - if not var: - break - options.append(var.name) - - return options - - def _parse_references(self, match: bool = True) -> t.Optional[exp.Reference]: - if match and not self._match(TokenType.REFERENCES): - return None - - expressions = None - this = self._parse_table(schema=True) - options = self._parse_key_constraint_options() - return self.expression(exp.Reference, this=this, expressions=expressions, options=options) - - def _parse_foreign_key(self) -> exp.ForeignKey: - expressions = ( - self._parse_wrapped_id_vars() - if not self._match(TokenType.REFERENCES, advance=False) - else None - ) - reference = self._parse_references() - on_options = {} - - while self._match(TokenType.ON): - if not self._match_set((TokenType.DELETE, TokenType.UPDATE)): - self.raise_error("Expected DELETE or UPDATE") - - kind = self._prev.text.lower() - - if self._match_text_seq("NO", "ACTION"): - action = "NO ACTION" - elif self._match(TokenType.SET): - self._match_set((TokenType.NULL, TokenType.DEFAULT)) - action = "SET " + self._prev.text.upper() - else: - self._advance() - action = self._prev.text.upper() - - on_options[kind] = action - - return self.expression( - exp.ForeignKey, - 
expressions=expressions, - reference=reference, - options=self._parse_key_constraint_options(), - **on_options, # type: ignore - ) - - def _parse_primary_key_part(self) -> t.Optional[exp.Expression]: - return self._parse_ordered() or self._parse_field() - - def _parse_period_for_system_time(self) -> t.Optional[exp.PeriodForSystemTimeConstraint]: - if not self._match(TokenType.TIMESTAMP_SNAPSHOT): - self._retreat(self._index - 1) - return None - - id_vars = self._parse_wrapped_id_vars() - return self.expression( - exp.PeriodForSystemTimeConstraint, - this=seq_get(id_vars, 0), - expression=seq_get(id_vars, 1), - ) - - def _parse_primary_key( - self, wrapped_optional: bool = False, in_props: bool = False - ) -> exp.PrimaryKeyColumnConstraint | exp.PrimaryKey: - desc = ( - self._match_set((TokenType.ASC, TokenType.DESC)) - and self._prev.token_type == TokenType.DESC - ) - - if not in_props and not self._match(TokenType.L_PAREN, advance=False): - return self.expression( - exp.PrimaryKeyColumnConstraint, - desc=desc, - options=self._parse_key_constraint_options(), - ) - - expressions = self._parse_wrapped_csv( - self._parse_primary_key_part, optional=wrapped_optional - ) - options = self._parse_key_constraint_options() - return self.expression(exp.PrimaryKey, expressions=expressions, options=options) - - def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: - return self._parse_slice(self._parse_alias(self._parse_assignment(), explicit=True)) - - def _parse_odbc_datetime_literal(self) -> exp.Expression: - """ - Parses a datetime column in ODBC format. We parse the column into the corresponding - types, for example `{d'yyyy-mm-dd'}` will be parsed as a `Date` column, exactly the - same as we did for `DATE('yyyy-mm-dd')`. - - Reference: - https://learn.microsoft.com/en-us/sql/odbc/reference/develop-app/date-time-and-timestamp-literals - """ - self._match(TokenType.VAR) - exp_class = self.ODBC_DATETIME_LITERALS[self._prev.text.lower()] - expression = self.expression(exp_class=exp_class, this=self._parse_string()) - if not self._match(TokenType.R_BRACE): - self.raise_error("Expected }") - return expression - - def _parse_bracket(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: - if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)): - return this - - bracket_kind = self._prev.token_type - if ( - bracket_kind == TokenType.L_BRACE - and self._curr - and self._curr.token_type == TokenType.VAR - and self._curr.text.lower() in self.ODBC_DATETIME_LITERALS - ): - return self._parse_odbc_datetime_literal() - - expressions = self._parse_csv( - lambda: self._parse_bracket_key_value(is_map=bracket_kind == TokenType.L_BRACE) - ) - - if bracket_kind == TokenType.L_BRACKET and not self._match(TokenType.R_BRACKET): - self.raise_error("Expected ]") - elif bracket_kind == TokenType.L_BRACE and not self._match(TokenType.R_BRACE): - self.raise_error("Expected }") - - # https://duckdb.org/docs/sql/data_types/struct.html#creating-structs - if bracket_kind == TokenType.L_BRACE: - this = self.expression(exp.Struct, expressions=self._kv_to_prop_eq(expressions)) - elif not this: - this = build_array_constructor( - exp.Array, args=expressions, bracket_kind=bracket_kind, dialect=self.dialect - ) - else: - constructor_type = self.ARRAY_CONSTRUCTORS.get(this.name.upper()) - if constructor_type: - return build_array_constructor( - constructor_type, - args=expressions, - bracket_kind=bracket_kind, - dialect=self.dialect, - ) - - expressions = 
apply_index_offset( - this, expressions, -self.dialect.INDEX_OFFSET, dialect=self.dialect - ) - this = self.expression( - exp.Bracket, - this=this, - expressions=expressions, - comments=this.pop_comments(), - ) - - self._add_comments(this) - return self._parse_bracket(this) - - def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if self._match(TokenType.COLON): - return self.expression(exp.Slice, this=this, expression=self._parse_assignment()) - return this - - def _parse_case(self) -> t.Optional[exp.Expression]: - ifs = [] - default = None - - comments = self._prev_comments - expression = self._parse_assignment() - - while self._match(TokenType.WHEN): - this = self._parse_assignment() - self._match(TokenType.THEN) - then = self._parse_assignment() - ifs.append(self.expression(exp.If, this=this, true=then)) - - if self._match(TokenType.ELSE): - default = self._parse_assignment() - - if not self._match(TokenType.END): - if isinstance(default, exp.Interval) and default.this.sql().upper() == "END": - default = exp.column("interval") - else: - self.raise_error("Expected END after CASE", self._prev) - - return self.expression( - exp.Case, comments=comments, this=expression, ifs=ifs, default=default - ) - - def _parse_if(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.L_PAREN): - args = self._parse_csv( - lambda: self._parse_alias(self._parse_assignment(), explicit=True) - ) - this = self.validate_expression(exp.If.from_arg_list(args), args) - self._match_r_paren() - else: - index = self._index - 1 - - if self.NO_PAREN_IF_COMMANDS and index == 0: - return self._parse_as_command(self._prev) - - condition = self._parse_assignment() - - if not condition: - self._retreat(index) - return None - - self._match(TokenType.THEN) - true = self._parse_assignment() - false = self._parse_assignment() if self._match(TokenType.ELSE) else None - self._match(TokenType.END) - this = self.expression(exp.If, this=condition, true=true, false=false) - - return this - - def _parse_next_value_for(self) -> t.Optional[exp.Expression]: - if not self._match_text_seq("VALUE", "FOR"): - self._retreat(self._index - 1) - return None - - return self.expression( - exp.NextValueFor, - this=self._parse_column(), - order=self._match(TokenType.OVER) and self._parse_wrapped(self._parse_order), - ) - - def _parse_extract(self) -> exp.Extract | exp.Anonymous: - this = self._parse_function() or self._parse_var_or_string(upper=True) - - if self._match(TokenType.FROM): - return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) - - if not self._match(TokenType.COMMA): - self.raise_error("Expected FROM or comma after EXTRACT", self._prev) - - return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) - - def _parse_gap_fill(self) -> exp.GapFill: - self._match(TokenType.TABLE) - this = self._parse_table() - - self._match(TokenType.COMMA) - args = [this, *self._parse_csv(self._parse_lambda)] - - gap_fill = exp.GapFill.from_arg_list(args) - return self.validate_expression(gap_fill, args) - - def _parse_cast(self, strict: bool, safe: t.Optional[bool] = None) -> exp.Expression: - this = self._parse_assignment() - - if not self._match(TokenType.ALIAS): - if self._match(TokenType.COMMA): - return self.expression(exp.CastToStrType, this=this, to=self._parse_string()) - - self.raise_error("Expected AS after CAST") - - fmt = None - to = self._parse_types() - - default = self._match(TokenType.DEFAULT) - if default: - default = self._parse_bitwise() - 
self._match_text_seq("ON", "CONVERSION", "ERROR") - - if self._match_set((TokenType.FORMAT, TokenType.COMMA)): - fmt_string = self._parse_string() - fmt = self._parse_at_time_zone(fmt_string) - - if not to: - to = exp.DataType.build(exp.DataType.Type.UNKNOWN) - if to.this in exp.DataType.TEMPORAL_TYPES: - this = self.expression( - exp.StrToDate if to.this == exp.DataType.Type.DATE else exp.StrToTime, - this=this, - format=exp.Literal.string( - format_time( - fmt_string.this if fmt_string else "", - self.dialect.FORMAT_MAPPING or self.dialect.TIME_MAPPING, - self.dialect.FORMAT_TRIE or self.dialect.TIME_TRIE, - ) - ), - safe=safe, - ) - - if isinstance(fmt, exp.AtTimeZone) and isinstance(this, exp.StrToTime): - this.set("zone", fmt.args["zone"]) - return this - elif not to: - self.raise_error("Expected TYPE after CAST") - elif isinstance(to, exp.Identifier): - to = exp.DataType.build(to.name, udt=True) - elif to.this == exp.DataType.Type.CHAR: - if self._match(TokenType.CHARACTER_SET): - to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) - - return self.expression( - exp.Cast if strict else exp.TryCast, - this=this, - to=to, - format=fmt, - safe=safe, - action=self._parse_var_from_options(self.CAST_ACTIONS, raise_unmatched=False), - default=default, - ) - - def _parse_string_agg(self) -> exp.GroupConcat: - if self._match(TokenType.DISTINCT): - args: t.List[t.Optional[exp.Expression]] = [ - self.expression(exp.Distinct, expressions=[self._parse_assignment()]) - ] - if self._match(TokenType.COMMA): - args.extend(self._parse_csv(self._parse_assignment)) - else: - args = self._parse_csv(self._parse_assignment) # type: ignore - - if self._match_text_seq("ON", "OVERFLOW"): - # trino: LISTAGG(expression [, separator] [ON OVERFLOW overflow_behavior]) - if self._match_text_seq("ERROR"): - on_overflow: t.Optional[exp.Expression] = exp.var("ERROR") - else: - self._match_text_seq("TRUNCATE") - on_overflow = self.expression( - exp.OverflowTruncateBehavior, - this=self._parse_string(), - with_count=( - self._match_text_seq("WITH", "COUNT") - or not self._match_text_seq("WITHOUT", "COUNT") - ), - ) - else: - on_overflow = None - - index = self._index - if not self._match(TokenType.R_PAREN) and args: - # postgres: STRING_AGG([DISTINCT] expression, separator [ORDER BY expression1 {ASC | DESC} [, ...]]) - # bigquery: STRING_AGG([DISTINCT] expression [, separator] [ORDER BY key [{ASC | DESC}] [, ... ]] [LIMIT n]) - # The order is parsed through `this` as a canonicalization for WITHIN GROUPs - args[0] = self._parse_limit(this=self._parse_order(this=args[0])) - return self.expression(exp.GroupConcat, this=args[0], separator=seq_get(args, 1)) - - # Checks if we can parse an order clause: WITHIN GROUP (ORDER BY [ASC | DESC]). - # This is done "manually", instead of letting _parse_window parse it into an exp.WithinGroup node, so that - # the STRING_AGG call is parsed like in MySQL / SQLite and can thus be transpiled more easily to them. 
- if not self._match_text_seq("WITHIN", "GROUP"): - self._retreat(index) - return self.validate_expression(exp.GroupConcat.from_arg_list(args), args) - - # The corresponding match_r_paren will be called in parse_function (caller) - self._match_l_paren() - - return self.expression( - exp.GroupConcat, - this=self._parse_order(this=seq_get(args, 0)), - separator=seq_get(args, 1), - on_overflow=on_overflow, - ) - - def _parse_convert( - self, strict: bool, safe: t.Optional[bool] = None - ) -> t.Optional[exp.Expression]: - this = self._parse_bitwise() - - if self._match(TokenType.USING): - to: t.Optional[exp.Expression] = self.expression( - exp.CharacterSet, this=self._parse_var() - ) - elif self._match(TokenType.COMMA): - to = self._parse_types() - else: - to = None - - return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to, safe=safe) - - def _parse_xml_table(self) -> exp.XMLTable: - namespaces = None - passing = None - columns = None - - if self._match_text_seq("XMLNAMESPACES", "("): - namespaces = self._parse_xml_namespace() - self._match_text_seq(")", ",") - - this = self._parse_string() - - if self._match_text_seq("PASSING"): - # The BY VALUE keywords are optional and are provided for semantic clarity - self._match_text_seq("BY", "VALUE") - passing = self._parse_csv(self._parse_column) - - by_ref = self._match_text_seq("RETURNING", "SEQUENCE", "BY", "REF") - - if self._match_text_seq("COLUMNS"): - columns = self._parse_csv(self._parse_field_def) - - return self.expression( - exp.XMLTable, - this=this, - namespaces=namespaces, - passing=passing, - columns=columns, - by_ref=by_ref, - ) - - def _parse_xml_namespace(self) -> t.List[exp.XMLNamespace]: - namespaces = [] - - while True: - if self._match(TokenType.DEFAULT): - uri = self._parse_string() - else: - uri = self._parse_alias(self._parse_string()) - namespaces.append(self.expression(exp.XMLNamespace, this=uri)) - if not self._match(TokenType.COMMA): - break - - return namespaces - - def _parse_decode(self) -> t.Optional[exp.Decode | exp.Case]: - """ - There are generally two variants of the DECODE function: - - - DECODE(bin, charset) - - DECODE(expression, search, result [, search, result] ... [, default]) - - The second variant will always be parsed into a CASE expression. Note that NULL - needs special treatment, since we need to explicitly check for it with `IS NULL`, - instead of relying on pattern matching. 
- """ - args = self._parse_csv(self._parse_assignment) - - if len(args) < 3: - return self.expression(exp.Decode, this=seq_get(args, 0), charset=seq_get(args, 1)) - - expression, *expressions = args - if not expression: - return None - - ifs = [] - for search, result in zip(expressions[::2], expressions[1::2]): - if not search or not result: - return None - - if isinstance(search, exp.Literal): - ifs.append( - exp.If(this=exp.EQ(this=expression.copy(), expression=search), true=result) - ) - elif isinstance(search, exp.Null): - ifs.append( - exp.If(this=exp.Is(this=expression.copy(), expression=exp.Null()), true=result) - ) - else: - cond = exp.or_( - exp.EQ(this=expression.copy(), expression=search), - exp.and_( - exp.Is(this=expression.copy(), expression=exp.Null()), - exp.Is(this=search.copy(), expression=exp.Null()), - copy=False, - ), - copy=False, - ) - ifs.append(exp.If(this=cond, true=result)) - - return exp.Case(ifs=ifs, default=expressions[-1] if len(expressions) % 2 == 1 else None) - - def _parse_json_key_value(self) -> t.Optional[exp.JSONKeyValue]: - self._match_text_seq("KEY") - key = self._parse_column() - self._match_set(self.JSON_KEY_VALUE_SEPARATOR_TOKENS) - self._match_text_seq("VALUE") - value = self._parse_bitwise() - - if not key and not value: - return None - return self.expression(exp.JSONKeyValue, this=key, expression=value) - - def _parse_format_json(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if not this or not self._match_text_seq("FORMAT", "JSON"): - return this - - return self.expression(exp.FormatJson, this=this) - - def _parse_on_condition(self) -> t.Optional[exp.OnCondition]: - # MySQL uses "X ON EMPTY Y ON ERROR" (e.g. JSON_VALUE) while Oracle uses the opposite (e.g. JSON_EXISTS) - if self.dialect.ON_CONDITION_EMPTY_BEFORE_ERROR: - empty = self._parse_on_handling("EMPTY", *self.ON_CONDITION_TOKENS) - error = self._parse_on_handling("ERROR", *self.ON_CONDITION_TOKENS) - else: - error = self._parse_on_handling("ERROR", *self.ON_CONDITION_TOKENS) - empty = self._parse_on_handling("EMPTY", *self.ON_CONDITION_TOKENS) - - null = self._parse_on_handling("NULL", *self.ON_CONDITION_TOKENS) - - if not empty and not error and not null: - return None - - return self.expression( - exp.OnCondition, - empty=empty, - error=error, - null=null, - ) - - def _parse_on_handling( - self, on: str, *values: str - ) -> t.Optional[str] | t.Optional[exp.Expression]: - # Parses the "X ON Y" or "DEFAULT ON Y syntax, e.g. NULL ON NULL (Oracle, T-SQL, MySQL) - for value in values: - if self._match_text_seq(value, "ON", on): - return f"{value} ON {on}" - - index = self._index - if self._match(TokenType.DEFAULT): - default_value = self._parse_bitwise() - if self._match_text_seq("ON", on): - return default_value - - self._retreat(index) - - return None - - @t.overload - def _parse_json_object(self, agg: Lit[False]) -> exp.JSONObject: ... - - @t.overload - def _parse_json_object(self, agg: Lit[True]) -> exp.JSONObjectAgg: ... 
- - def _parse_json_object(self, agg=False): - star = self._parse_star() - expressions = ( - [star] - if star - else self._parse_csv(lambda: self._parse_format_json(self._parse_json_key_value())) - ) - null_handling = self._parse_on_handling("NULL", "NULL", "ABSENT") - - unique_keys = None - if self._match_text_seq("WITH", "UNIQUE"): - unique_keys = True - elif self._match_text_seq("WITHOUT", "UNIQUE"): - unique_keys = False - - self._match_text_seq("KEYS") - - return_type = self._match_text_seq("RETURNING") and self._parse_format_json( - self._parse_type() - ) - encoding = self._match_text_seq("ENCODING") and self._parse_var() - - return self.expression( - exp.JSONObjectAgg if agg else exp.JSONObject, - expressions=expressions, - null_handling=null_handling, - unique_keys=unique_keys, - return_type=return_type, - encoding=encoding, - ) - - # Note: this is currently incomplete; it only implements the "JSON_value_column" part - def _parse_json_column_def(self) -> exp.JSONColumnDef: - if not self._match_text_seq("NESTED"): - this = self._parse_id_var() - kind = self._parse_types(allow_identifiers=False) - nested = None - else: - this = None - kind = None - nested = True - - path = self._match_text_seq("PATH") and self._parse_string() - nested_schema = nested and self._parse_json_schema() - - return self.expression( - exp.JSONColumnDef, - this=this, - kind=kind, - path=path, - nested_schema=nested_schema, - ) - - def _parse_json_schema(self) -> exp.JSONSchema: - self._match_text_seq("COLUMNS") - return self.expression( - exp.JSONSchema, - expressions=self._parse_wrapped_csv(self._parse_json_column_def, optional=True), - ) - - def _parse_json_table(self) -> exp.JSONTable: - this = self._parse_format_json(self._parse_bitwise()) - path = self._match(TokenType.COMMA) and self._parse_string() - error_handling = self._parse_on_handling("ERROR", "ERROR", "NULL") - empty_handling = self._parse_on_handling("EMPTY", "ERROR", "NULL") - schema = self._parse_json_schema() - - return exp.JSONTable( - this=this, - schema=schema, - path=path, - error_handling=error_handling, - empty_handling=empty_handling, - ) - - def _parse_match_against(self) -> exp.MatchAgainst: - expressions = self._parse_csv(self._parse_column) - - self._match_text_seq(")", "AGAINST", "(") - - this = self._parse_string() - - if self._match_text_seq("IN", "NATURAL", "LANGUAGE", "MODE"): - modifier = "IN NATURAL LANGUAGE MODE" - if self._match_text_seq("WITH", "QUERY", "EXPANSION"): - modifier = f"{modifier} WITH QUERY EXPANSION" - elif self._match_text_seq("IN", "BOOLEAN", "MODE"): - modifier = "IN BOOLEAN MODE" - elif self._match_text_seq("WITH", "QUERY", "EXPANSION"): - modifier = "WITH QUERY EXPANSION" - else: - modifier = None - - return self.expression( - exp.MatchAgainst, this=this, expressions=expressions, modifier=modifier - ) - - # https://learn.microsoft.com/en-us/sql/t-sql/functions/openjson-transact-sql?view=sql-server-ver16 - def _parse_open_json(self) -> exp.OpenJSON: - this = self._parse_bitwise() - path = self._match(TokenType.COMMA) and self._parse_string() - - def _parse_open_json_column_def() -> exp.OpenJSONColumnDef: - this = self._parse_field(any_token=True) - kind = self._parse_types() - path = self._parse_string() - as_json = self._match_pair(TokenType.ALIAS, TokenType.JSON) - - return self.expression( - exp.OpenJSONColumnDef, this=this, kind=kind, path=path, as_json=as_json - ) - - expressions = None - if self._match_pair(TokenType.R_PAREN, TokenType.WITH): - self._match_l_paren() - expressions = 
self._parse_csv(_parse_open_json_column_def) - - return self.expression(exp.OpenJSON, this=this, path=path, expressions=expressions) - - def _parse_position(self, haystack_first: bool = False) -> exp.StrPosition: - args = self._parse_csv(self._parse_bitwise) - - if self._match(TokenType.IN): - return self.expression( - exp.StrPosition, this=self._parse_bitwise(), substr=seq_get(args, 0) - ) - - if haystack_first: - haystack = seq_get(args, 0) - needle = seq_get(args, 1) - else: - haystack = seq_get(args, 1) - needle = seq_get(args, 0) - - return self.expression( - exp.StrPosition, this=haystack, substr=needle, position=seq_get(args, 2) - ) - - def _parse_predict(self) -> exp.Predict: - self._match_text_seq("MODEL") - this = self._parse_table() - - self._match(TokenType.COMMA) - self._match_text_seq("TABLE") - - return self.expression( - exp.Predict, - this=this, - expression=self._parse_table(), - params_struct=self._match(TokenType.COMMA) and self._parse_bitwise(), - ) - - def _parse_join_hint(self, func_name: str) -> exp.JoinHint: - args = self._parse_csv(self._parse_table) - return exp.JoinHint(this=func_name.upper(), expressions=args) - - def _parse_substring(self) -> exp.Substring: - # Postgres supports the form: substring(string [from int] [for int]) - # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 - - args = t.cast(t.List[t.Optional[exp.Expression]], self._parse_csv(self._parse_bitwise)) - - if self._match(TokenType.FROM): - args.append(self._parse_bitwise()) - if self._match(TokenType.FOR): - if len(args) == 1: - args.append(exp.Literal.number(1)) - args.append(self._parse_bitwise()) - - return self.validate_expression(exp.Substring.from_arg_list(args), args) - - def _parse_trim(self) -> exp.Trim: - # https://www.w3resource.com/sql/character-functions/trim.php - # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html - - position = None - collation = None - expression = None - - if self._match_texts(self.TRIM_TYPES): - position = self._prev.text.upper() - - this = self._parse_bitwise() - if self._match_set((TokenType.FROM, TokenType.COMMA)): - invert_order = self._prev.token_type == TokenType.FROM or self.TRIM_PATTERN_FIRST - expression = self._parse_bitwise() - - if invert_order: - this, expression = expression, this - - if self._match(TokenType.COLLATE): - collation = self._parse_bitwise() - - return self.expression( - exp.Trim, this=this, position=position, expression=expression, collation=collation - ) - - def _parse_window_clause(self) -> t.Optional[t.List[exp.Expression]]: - return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window) - - def _parse_named_window(self) -> t.Optional[exp.Expression]: - return self._parse_window(self._parse_id_var(), alias=True) - - def _parse_respect_or_ignore_nulls( - self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Expression]: - if self._match_text_seq("IGNORE", "NULLS"): - return self.expression(exp.IgnoreNulls, this=this) - if self._match_text_seq("RESPECT", "NULLS"): - return self.expression(exp.RespectNulls, this=this) - return this - - def _parse_having_max(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: - if self._match(TokenType.HAVING): - self._match_texts(("MAX", "MIN")) - max = self._prev.text.upper() != "MIN" - return self.expression( - exp.HavingMax, this=this, expression=self._parse_column(), max=max - ) - - return this - - def _parse_window( - self, this: t.Optional[exp.Expression], alias: bool = False - ) -> t.Optional[exp.Expression]: - 
func = this - comments = func.comments if isinstance(func, exp.Expression) else None - - # T-SQL allows the OVER (...) syntax after WITHIN GROUP. - # https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16 - if self._match_text_seq("WITHIN", "GROUP"): - order = self._parse_wrapped(self._parse_order) - this = self.expression(exp.WithinGroup, this=this, expression=order) - - if self._match_pair(TokenType.FILTER, TokenType.L_PAREN): - self._match(TokenType.WHERE) - this = self.expression( - exp.Filter, this=this, expression=self._parse_where(skip_where_token=True) - ) - self._match_r_paren() - - # SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER - # Some dialects choose to implement and some do not. - # https://dev.mysql.com/doc/refman/8.0/en/window-function-descriptions.html - - # There is some code above in _parse_lambda that handles - # SELECT FIRST_VALUE(TABLE.COLUMN IGNORE|RESPECT NULLS) OVER ... - - # The below changes handle - # SELECT FIRST_VALUE(TABLE.COLUMN) IGNORE|RESPECT NULLS OVER ... - - # Oracle allows both formats - # (https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/img_text/first_value.html) - # and Snowflake chose to do the same for familiarity - # https://docs.snowflake.com/en/sql-reference/functions/first_value.html#usage-notes - if isinstance(this, exp.AggFunc): - ignore_respect = this.find(exp.IgnoreNulls, exp.RespectNulls) - - if ignore_respect and ignore_respect is not this: - ignore_respect.replace(ignore_respect.this) - this = self.expression(ignore_respect.__class__, this=this) - - this = self._parse_respect_or_ignore_nulls(this) - - # bigquery select from window x AS (partition by ...) - if alias: - over = None - self._match(TokenType.ALIAS) - elif not self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS): - return this - else: - over = self._prev.text.upper() - - if comments and isinstance(func, exp.Expression): - func.pop_comments() - - if not self._match(TokenType.L_PAREN): - return self.expression( - exp.Window, - comments=comments, - this=this, - alias=self._parse_id_var(False), - over=over, - ) - - window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS) - - first = self._match(TokenType.FIRST) - if self._match_text_seq("LAST"): - first = False - - partition, order = self._parse_partition_and_order() - kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text - - if kind: - self._match(TokenType.BETWEEN) - start = self._parse_window_spec() - self._match(TokenType.AND) - end = self._parse_window_spec() - exclude = ( - self._parse_var_from_options(self.WINDOW_EXCLUDE_OPTIONS) - if self._match_text_seq("EXCLUDE") - else None - ) - - spec = self.expression( - exp.WindowSpec, - kind=kind, - start=start["value"], - start_side=start["side"], - end=end["value"], - end_side=end["side"], - exclude=exclude, - ) - else: - spec = None - - self._match_r_paren() - - window = self.expression( - exp.Window, - comments=comments, - this=this, - partition_by=partition, - order=order, - spec=spec, - alias=window_alias, - over=over, - first=first, - ) - - # This covers Oracle's FIRST/LAST syntax: aggregate KEEP (...) OVER (...) 
- if self._match_set(self.WINDOW_BEFORE_PAREN_TOKENS, advance=False): - return self._parse_window(window, alias=alias) - - return window - - def _parse_partition_and_order( - self, - ) -> t.Tuple[t.List[exp.Expression], t.Optional[exp.Expression]]: - return self._parse_partition_by(), self._parse_order() - - def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: - self._match(TokenType.BETWEEN) - - return { - "value": ( - (self._match_text_seq("UNBOUNDED") and "UNBOUNDED") - or (self._match_text_seq("CURRENT", "ROW") and "CURRENT ROW") - or self._parse_bitwise() - ), - "side": self._match_texts(self.WINDOW_SIDES) and self._prev.text, - } - - def _parse_alias( - self, this: t.Optional[exp.Expression], explicit: bool = False - ) -> t.Optional[exp.Expression]: - # In some dialects, LIMIT and OFFSET can act as both identifiers and keywords (clauses) - # so this section tries to parse the clause version and if it fails, it treats the token - # as an identifier (alias) - if self._can_parse_limit_or_offset(): - return this - - any_token = self._match(TokenType.ALIAS) - comments = self._prev_comments or [] - - if explicit and not any_token: - return this - - if self._match(TokenType.L_PAREN): - aliases = self.expression( - exp.Aliases, - comments=comments, - this=this, - expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), - ) - self._match_r_paren(aliases) - return aliases - - alias = self._parse_id_var(any_token, tokens=self.ALIAS_TOKENS) or ( - self.STRING_ALIASES and self._parse_string_as_identifier() - ) - - if alias: - comments.extend(alias.pop_comments()) - this = self.expression(exp.Alias, comments=comments, this=this, alias=alias) - column = this.this - - # Moves the comment next to the alias in `expr /* comment */ AS alias` - if not this.comments and column and column.comments: - this.comments = column.pop_comments() - - return this - - def _parse_id_var( - self, - any_token: bool = True, - tokens: t.Optional[t.Collection[TokenType]] = None, - ) -> t.Optional[exp.Expression]: - expression = self._parse_identifier() - if not expression and ( - (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS) - ): - quoted = self._prev.token_type == TokenType.STRING - expression = self._identifier_expression(quoted=quoted) - - return expression - - def _parse_string(self) -> t.Optional[exp.Expression]: - if self._match_set(self.STRING_PARSERS): - return self.STRING_PARSERS[self._prev.token_type](self, self._prev) - return self._parse_placeholder() - - def _parse_string_as_identifier(self) -> t.Optional[exp.Identifier]: - output = exp.to_identifier(self._match(TokenType.STRING) and self._prev.text, quoted=True) - if output: - output.update_positions(self._prev) - return output - - def _parse_number(self) -> t.Optional[exp.Expression]: - if self._match_set(self.NUMERIC_PARSERS): - return self.NUMERIC_PARSERS[self._prev.token_type](self, self._prev) - return self._parse_placeholder() - - def _parse_identifier(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.IDENTIFIER): - return self._identifier_expression(quoted=True) - return self._parse_placeholder() - - def _parse_var( - self, - any_token: bool = False, - tokens: t.Optional[t.Collection[TokenType]] = None, - upper: bool = False, - ) -> t.Optional[exp.Expression]: - if ( - (any_token and self._advance_any()) - or self._match(TokenType.VAR) - or (self._match_set(tokens) if tokens else False) - ): - return self.expression( - exp.Var, this=self._prev.text.upper() if 
upper else self._prev.text - ) - return self._parse_placeholder() - - def _advance_any(self, ignore_reserved: bool = False) -> t.Optional[Token]: - if self._curr and (ignore_reserved or self._curr.token_type not in self.RESERVED_TOKENS): - self._advance() - return self._prev - return None - - def _parse_var_or_string(self, upper: bool = False) -> t.Optional[exp.Expression]: - return self._parse_string() or self._parse_var(any_token=True, upper=upper) - - def _parse_primary_or_var(self) -> t.Optional[exp.Expression]: - return self._parse_primary() or self._parse_var(any_token=True) - - def _parse_null(self) -> t.Optional[exp.Expression]: - if self._match_set(self.NULL_TOKENS): - return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) - return self._parse_placeholder() - - def _parse_boolean(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.TRUE): - return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev) - if self._match(TokenType.FALSE): - return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev) - return self._parse_placeholder() - - def _parse_star(self) -> t.Optional[exp.Expression]: - if self._match(TokenType.STAR): - return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) - return self._parse_placeholder() - - def _parse_parameter(self) -> exp.Parameter: - this = self._parse_identifier() or self._parse_primary_or_var() - return self.expression(exp.Parameter, this=this) - - def _parse_placeholder(self) -> t.Optional[exp.Expression]: - if self._match_set(self.PLACEHOLDER_PARSERS): - placeholder = self.PLACEHOLDER_PARSERS[self._prev.token_type](self) - if placeholder: - return placeholder - self._advance(-1) - return None - - def _parse_star_op(self, *keywords: str) -> t.Optional[t.List[exp.Expression]]: - if not self._match_texts(keywords): - return None - if self._match(TokenType.L_PAREN, advance=False): - return self._parse_wrapped_csv(self._parse_expression) - - expression = self._parse_expression() - return [expression] if expression else None - - def _parse_csv( - self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA - ) -> t.List[exp.Expression]: - parse_result = parse_method() - items = [parse_result] if parse_result is not None else [] - - while self._match(sep): - self._add_comments(parse_result) - parse_result = parse_method() - if parse_result is not None: - items.append(parse_result) - - return items - - def _parse_tokens( - self, parse_method: t.Callable, expressions: t.Dict - ) -> t.Optional[exp.Expression]: - this = parse_method() - - while self._match_set(expressions): - this = self.expression( - expressions[self._prev.token_type], - this=this, - comments=self._prev_comments, - expression=parse_method(), - ) - - return this - - def _parse_wrapped_id_vars(self, optional: bool = False) -> t.List[exp.Expression]: - return self._parse_wrapped_csv(self._parse_id_var, optional=optional) - - def _parse_wrapped_csv( - self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA, optional: bool = False - ) -> t.List[exp.Expression]: - return self._parse_wrapped( - lambda: self._parse_csv(parse_method, sep=sep), optional=optional - ) - - def _parse_wrapped(self, parse_method: t.Callable, optional: bool = False) -> t.Any: - wrapped = self._match(TokenType.L_PAREN) - if not wrapped and not optional: - self.raise_error("Expecting (") - parse_result = parse_method() - if wrapped: - self._match_r_paren() - return parse_result - - def _parse_expressions(self) -> t.List[exp.Expression]: - return self._parse_csv(self._parse_expression) - 
- def _parse_select_or_expression(self, alias: bool = False) -> t.Optional[exp.Expression]: - return self._parse_select() or self._parse_set_operations( - self._parse_alias(self._parse_assignment(), explicit=True) - if alias - else self._parse_assignment() - ) - - def _parse_ddl_select(self) -> t.Optional[exp.Expression]: - return self._parse_query_modifiers( - self._parse_set_operations(self._parse_select(nested=True, parse_subquery_alias=False)) - ) - - def _parse_transaction(self) -> exp.Transaction | exp.Command: - this = None - if self._match_texts(self.TRANSACTION_KIND): - this = self._prev.text - - self._match_texts(("TRANSACTION", "WORK")) - - modes = [] - while True: - mode = [] - while self._match(TokenType.VAR): - mode.append(self._prev.text) - - if mode: - modes.append(" ".join(mode)) - if not self._match(TokenType.COMMA): - break - - return self.expression(exp.Transaction, this=this, modes=modes) - - def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: - chain = None - savepoint = None - is_rollback = self._prev.token_type == TokenType.ROLLBACK - - self._match_texts(("TRANSACTION", "WORK")) - - if self._match_text_seq("TO"): - self._match_text_seq("SAVEPOINT") - savepoint = self._parse_id_var() - - if self._match(TokenType.AND): - chain = not self._match_text_seq("NO") - self._match_text_seq("CHAIN") - - if is_rollback: - return self.expression(exp.Rollback, savepoint=savepoint) - - return self.expression(exp.Commit, chain=chain) - - def _parse_refresh(self) -> exp.Refresh: - self._match(TokenType.TABLE) - return self.expression(exp.Refresh, this=self._parse_string() or self._parse_table()) - - def _parse_add_column(self) -> t.Optional[exp.ColumnDef]: - if not self._prev.text.upper() == "ADD": - return None - - start = self._index - self._match(TokenType.COLUMN) - - exists_column = self._parse_exists(not_=True) - expression = self._parse_field_def() - - if not isinstance(expression, exp.ColumnDef): - self._retreat(start) - return None - - expression.set("exists", exists_column) - - # https://docs.databricks.com/delta/update-schema.html#explicitly-update-schema-to-add-columns - if self._match_texts(("FIRST", "AFTER")): - position = self._prev.text - column_position = self.expression( - exp.ColumnPosition, this=self._parse_column(), position=position - ) - expression.set("position", column_position) - - return expression - - def _parse_drop_column(self) -> t.Optional[exp.Drop | exp.Command]: - drop = self._match(TokenType.DROP) and self._parse_drop() - if drop and not isinstance(drop, exp.Command): - drop.set("kind", drop.args.get("kind", "COLUMN")) - return drop - - # https://docs.aws.amazon.com/athena/latest/ug/alter-table-drop-partition.html - def _parse_drop_partition(self, exists: t.Optional[bool] = None) -> exp.DropPartition: - return self.expression( - exp.DropPartition, expressions=self._parse_csv(self._parse_partition), exists=exists - ) - - def _parse_alter_table_add(self) -> t.List[exp.Expression]: - def _parse_add_alteration() -> t.Optional[exp.Expression]: - self._match_text_seq("ADD") - if self._match_set(self.ADD_CONSTRAINT_TOKENS, advance=False): - return self.expression( - exp.AddConstraint, expressions=self._parse_csv(self._parse_constraint) - ) - - column_def = self._parse_add_column() - if isinstance(column_def, exp.ColumnDef): - return column_def - - exists = self._parse_exists(not_=True) - if self._match_pair(TokenType.PARTITION, TokenType.L_PAREN, advance=False): - return self.expression( - exp.AddPartition, exists=exists, 
this=self._parse_field(any_token=True) - ) - - return None - - if not self.dialect.ALTER_TABLE_ADD_REQUIRED_FOR_EACH_COLUMN or self._match_text_seq( - "COLUMNS" - ): - schema = self._parse_schema() - - return ensure_list(schema) if schema else self._parse_csv(self._parse_field_def) - - return self._parse_csv(_parse_add_alteration) - - def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]: - if self._match_texts(self.ALTER_ALTER_PARSERS): - return self.ALTER_ALTER_PARSERS[self._prev.text.upper()](self) - - # Many dialects support the ALTER [COLUMN] syntax, so if there is no - # keyword after ALTER we default to parsing this statement - self._match(TokenType.COLUMN) - column = self._parse_field(any_token=True) - - if self._match_pair(TokenType.DROP, TokenType.DEFAULT): - return self.expression(exp.AlterColumn, this=column, drop=True) - if self._match_pair(TokenType.SET, TokenType.DEFAULT): - return self.expression(exp.AlterColumn, this=column, default=self._parse_assignment()) - if self._match(TokenType.COMMENT): - return self.expression(exp.AlterColumn, this=column, comment=self._parse_string()) - if self._match_text_seq("DROP", "NOT", "NULL"): - return self.expression( - exp.AlterColumn, - this=column, - drop=True, - allow_null=True, - ) - if self._match_text_seq("SET", "NOT", "NULL"): - return self.expression( - exp.AlterColumn, - this=column, - allow_null=False, - ) - - if self._match_text_seq("SET", "VISIBLE"): - return self.expression(exp.AlterColumn, this=column, visible="VISIBLE") - if self._match_text_seq("SET", "INVISIBLE"): - return self.expression(exp.AlterColumn, this=column, visible="INVISIBLE") - - self._match_text_seq("SET", "DATA") - self._match_text_seq("TYPE") - return self.expression( - exp.AlterColumn, - this=column, - dtype=self._parse_types(), - collate=self._match(TokenType.COLLATE) and self._parse_term(), - using=self._match(TokenType.USING) and self._parse_assignment(), - ) - - def _parse_alter_diststyle(self) -> exp.AlterDistStyle: - if self._match_texts(("ALL", "EVEN", "AUTO")): - return self.expression(exp.AlterDistStyle, this=exp.var(self._prev.text.upper())) - - self._match_text_seq("KEY", "DISTKEY") - return self.expression(exp.AlterDistStyle, this=self._parse_column()) - - def _parse_alter_sortkey(self, compound: t.Optional[bool] = None) -> exp.AlterSortKey: - if compound: - self._match_text_seq("SORTKEY") - - if self._match(TokenType.L_PAREN, advance=False): - return self.expression( - exp.AlterSortKey, expressions=self._parse_wrapped_id_vars(), compound=compound - ) - - self._match_texts(("AUTO", "NONE")) - return self.expression( - exp.AlterSortKey, this=exp.var(self._prev.text.upper()), compound=compound - ) - - def _parse_alter_table_drop(self) -> t.List[exp.Expression]: - index = self._index - 1 - - partition_exists = self._parse_exists() - if self._match(TokenType.PARTITION, advance=False): - return self._parse_csv(lambda: self._parse_drop_partition(exists=partition_exists)) - - self._retreat(index) - return self._parse_csv(self._parse_drop_column) - - def _parse_alter_table_rename(self) -> t.Optional[exp.AlterRename | exp.RenameColumn]: - if self._match(TokenType.COLUMN) or not self.ALTER_RENAME_REQUIRES_COLUMN: - exists = self._parse_exists() - old_column = self._parse_column() - to = self._match_text_seq("TO") - new_column = self._parse_column() - - if old_column is None or to is None or new_column is None: - return None - - return self.expression(exp.RenameColumn, this=old_column, to=new_column, exists=exists) - - 
self._match_text_seq("TO") - return self.expression(exp.AlterRename, this=self._parse_table(schema=True)) - - def _parse_alter_table_set(self) -> exp.AlterSet: - alter_set = self.expression(exp.AlterSet) - - if self._match(TokenType.L_PAREN, advance=False) or self._match_text_seq( - "TABLE", "PROPERTIES" - ): - alter_set.set("expressions", self._parse_wrapped_csv(self._parse_assignment)) - elif self._match_text_seq("FILESTREAM_ON", advance=False): - alter_set.set("expressions", [self._parse_assignment()]) - elif self._match_texts(("LOGGED", "UNLOGGED")): - alter_set.set("option", exp.var(self._prev.text.upper())) - elif self._match_text_seq("WITHOUT") and self._match_texts(("CLUSTER", "OIDS")): - alter_set.set("option", exp.var(f"WITHOUT {self._prev.text.upper()}")) - elif self._match_text_seq("LOCATION"): - alter_set.set("location", self._parse_field()) - elif self._match_text_seq("ACCESS", "METHOD"): - alter_set.set("access_method", self._parse_field()) - elif self._match_text_seq("TABLESPACE"): - alter_set.set("tablespace", self._parse_field()) - elif self._match_text_seq("FILE", "FORMAT") or self._match_text_seq("FILEFORMAT"): - alter_set.set("file_format", [self._parse_field()]) - elif self._match_text_seq("STAGE_FILE_FORMAT"): - alter_set.set("file_format", self._parse_wrapped_options()) - elif self._match_text_seq("STAGE_COPY_OPTIONS"): - alter_set.set("copy_options", self._parse_wrapped_options()) - elif self._match_text_seq("TAG") or self._match_text_seq("TAGS"): - alter_set.set("tag", self._parse_csv(self._parse_assignment)) - else: - if self._match_text_seq("SERDE"): - alter_set.set("serde", self._parse_field()) - - properties = self._parse_wrapped(self._parse_properties, optional=True) - alter_set.set("expressions", [properties]) - - return alter_set - - def _parse_alter(self) -> exp.Alter | exp.Command: - start = self._prev - - alter_token = self._match_set(self.ALTERABLES) and self._prev - if not alter_token: - return self._parse_as_command(start) - - exists = self._parse_exists() - only = self._match_text_seq("ONLY") - this = self._parse_table(schema=True) - cluster = self._parse_on_property() if self._match(TokenType.ON) else None - - if self._next: - self._advance() - - parser = self.ALTER_PARSERS.get(self._prev.text.upper()) if self._prev else None - if parser: - actions = ensure_list(parser(self)) - not_valid = self._match_text_seq("NOT", "VALID") - options = self._parse_csv(self._parse_property) - - if not self._curr and actions: - return self.expression( - exp.Alter, - this=this, - kind=alter_token.text.upper(), - exists=exists, - actions=actions, - only=only, - options=options, - cluster=cluster, - not_valid=not_valid, - ) - - return self._parse_as_command(start) - - def _parse_analyze(self) -> exp.Analyze | exp.Command: - start = self._prev - # https://duckdb.org/docs/sql/statements/analyze - if not self._curr: - return self.expression(exp.Analyze) - - options = [] - while self._match_texts(self.ANALYZE_STYLES): - if self._prev.text.upper() == "BUFFER_USAGE_LIMIT": - options.append(f"BUFFER_USAGE_LIMIT {self._parse_number()}") - else: - options.append(self._prev.text.upper()) - - this: t.Optional[exp.Expression] = None - inner_expression: t.Optional[exp.Expression] = None - - kind = self._curr and self._curr.text.upper() - - if self._match(TokenType.TABLE) or self._match(TokenType.INDEX): - this = self._parse_table_parts() - elif self._match_text_seq("TABLES"): - if self._match_set((TokenType.FROM, TokenType.IN)): - kind = f"{kind} {self._prev.text.upper()}" - this = 
self._parse_table(schema=True, is_db_reference=True) - elif self._match_text_seq("DATABASE"): - this = self._parse_table(schema=True, is_db_reference=True) - elif self._match_text_seq("CLUSTER"): - this = self._parse_table() - # Try matching inner expr keywords before fallback to parse table. - elif self._match_texts(self.ANALYZE_EXPRESSION_PARSERS): - kind = None - inner_expression = self.ANALYZE_EXPRESSION_PARSERS[self._prev.text.upper()](self) - else: - # Empty kind https://prestodb.io/docs/current/sql/analyze.html - kind = None - this = self._parse_table_parts() - - partition = self._try_parse(self._parse_partition) - if not partition and self._match_texts(self.PARTITION_KEYWORDS): - return self._parse_as_command(start) - - # https://docs.starrocks.io/docs/sql-reference/sql-statements/cbo_stats/ANALYZE_TABLE/ - if self._match_text_seq("WITH", "SYNC", "MODE") or self._match_text_seq( - "WITH", "ASYNC", "MODE" - ): - mode = f"WITH {self._tokens[self._index - 2].text.upper()} MODE" - else: - mode = None - - if self._match_texts(self.ANALYZE_EXPRESSION_PARSERS): - inner_expression = self.ANALYZE_EXPRESSION_PARSERS[self._prev.text.upper()](self) - - properties = self._parse_properties() - return self.expression( - exp.Analyze, - kind=kind, - this=this, - mode=mode, - partition=partition, - properties=properties, - expression=inner_expression, - options=options, - ) - - # https://spark.apache.org/docs/3.5.1/sql-ref-syntax-aux-analyze-table.html - def _parse_analyze_statistics(self) -> exp.AnalyzeStatistics: - this = None - kind = self._prev.text.upper() - option = self._prev.text.upper() if self._match_text_seq("DELTA") else None - expressions = [] - - if not self._match_text_seq("STATISTICS"): - self.raise_error("Expecting token STATISTICS") - - if self._match_text_seq("NOSCAN"): - this = "NOSCAN" - elif self._match(TokenType.FOR): - if self._match_text_seq("ALL", "COLUMNS"): - this = "FOR ALL COLUMNS" - if self._match_texts("COLUMNS"): - this = "FOR COLUMNS" - expressions = self._parse_csv(self._parse_column_reference) - elif self._match_text_seq("SAMPLE"): - sample = self._parse_number() - expressions = [ - self.expression( - exp.AnalyzeSample, - sample=sample, - kind=self._prev.text.upper() if self._match(TokenType.PERCENT) else None, - ) - ] - - return self.expression( - exp.AnalyzeStatistics, kind=kind, option=option, this=this, expressions=expressions - ) - - # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ANALYZE.html - def _parse_analyze_validate(self) -> exp.AnalyzeValidate: - kind = None - this = None - expression: t.Optional[exp.Expression] = None - if self._match_text_seq("REF", "UPDATE"): - kind = "REF" - this = "UPDATE" - if self._match_text_seq("SET", "DANGLING", "TO", "NULL"): - this = "UPDATE SET DANGLING TO NULL" - elif self._match_text_seq("STRUCTURE"): - kind = "STRUCTURE" - if self._match_text_seq("CASCADE", "FAST"): - this = "CASCADE FAST" - elif self._match_text_seq("CASCADE", "COMPLETE") and self._match_texts( - ("ONLINE", "OFFLINE") - ): - this = f"CASCADE COMPLETE {self._prev.text.upper()}" - expression = self._parse_into() - - return self.expression(exp.AnalyzeValidate, kind=kind, this=this, expression=expression) - - def _parse_analyze_columns(self) -> t.Optional[exp.AnalyzeColumns]: - this = self._prev.text.upper() - if self._match_text_seq("COLUMNS"): - return self.expression(exp.AnalyzeColumns, this=f"{this} {self._prev.text.upper()}") - return None - - def _parse_analyze_delete(self) -> t.Optional[exp.AnalyzeDelete]: - kind = 
self._prev.text.upper() if self._match_text_seq("SYSTEM") else None - if self._match_text_seq("STATISTICS"): - return self.expression(exp.AnalyzeDelete, kind=kind) - return None - - def _parse_analyze_list(self) -> t.Optional[exp.AnalyzeListChainedRows]: - if self._match_text_seq("CHAINED", "ROWS"): - return self.expression(exp.AnalyzeListChainedRows, expression=self._parse_into()) - return None - - # https://dev.mysql.com/doc/refman/8.4/en/analyze-table.html - def _parse_analyze_histogram(self) -> exp.AnalyzeHistogram: - this = self._prev.text.upper() - expression: t.Optional[exp.Expression] = None - expressions = [] - update_options = None - - if self._match_text_seq("HISTOGRAM", "ON"): - expressions = self._parse_csv(self._parse_column_reference) - with_expressions = [] - while self._match(TokenType.WITH): - # https://docs.starrocks.io/docs/sql-reference/sql-statements/cbo_stats/ANALYZE_TABLE/ - if self._match_texts(("SYNC", "ASYNC")): - if self._match_text_seq("MODE", advance=False): - with_expressions.append(f"{self._prev.text.upper()} MODE") - self._advance() - else: - buckets = self._parse_number() - if self._match_text_seq("BUCKETS"): - with_expressions.append(f"{buckets} BUCKETS") - if with_expressions: - expression = self.expression(exp.AnalyzeWith, expressions=with_expressions) - - if self._match_texts(("MANUAL", "AUTO")) and self._match( - TokenType.UPDATE, advance=False - ): - update_options = self._prev.text.upper() - self._advance() - elif self._match_text_seq("USING", "DATA"): - expression = self.expression(exp.UsingData, this=self._parse_string()) - - return self.expression( - exp.AnalyzeHistogram, - this=this, - expressions=expressions, - expression=expression, - update_options=update_options, - ) - - def _parse_merge(self) -> exp.Merge: - self._match(TokenType.INTO) - target = self._parse_table() - - if target and self._match(TokenType.ALIAS, advance=False): - target.set("alias", self._parse_table_alias()) - - self._match(TokenType.USING) - using = self._parse_table() - - self._match(TokenType.ON) - on = self._parse_assignment() - - return self.expression( - exp.Merge, - this=target, - using=using, - on=on, - whens=self._parse_when_matched(), - returning=self._parse_returning(), - ) - - def _parse_when_matched(self) -> exp.Whens: - whens = [] - - while self._match(TokenType.WHEN): - matched = not self._match(TokenType.NOT) - self._match_text_seq("MATCHED") - source = ( - False - if self._match_text_seq("BY", "TARGET") - else self._match_text_seq("BY", "SOURCE") - ) - condition = self._parse_assignment() if self._match(TokenType.AND) else None - - self._match(TokenType.THEN) - - if self._match(TokenType.INSERT): - this = self._parse_star() - if this: - then: t.Optional[exp.Expression] = self.expression(exp.Insert, this=this) - else: - then = self.expression( - exp.Insert, - this=exp.var("ROW") - if self._match_text_seq("ROW") - else self._parse_value(values=False), - expression=self._match_text_seq("VALUES") and self._parse_value(), - ) - elif self._match(TokenType.UPDATE): - expressions = self._parse_star() - if expressions: - then = self.expression(exp.Update, expressions=expressions) - else: - then = self.expression( - exp.Update, - expressions=self._match(TokenType.SET) - and self._parse_csv(self._parse_equality), - ) - elif self._match(TokenType.DELETE): - then = self.expression(exp.Var, this=self._prev.text) - else: - then = self._parse_var_from_options(self.CONFLICT_ACTIONS) - - whens.append( - self.expression( - exp.When, - matched=matched, - source=source, - 
condition=condition, - then=then, - ) - ) - return self.expression(exp.Whens, expressions=whens) - - def _parse_show(self) -> t.Optional[exp.Expression]: - parser = self._find_parser(self.SHOW_PARSERS, self.SHOW_TRIE) - if parser: - return parser(self) - return self._parse_as_command(self._prev) - - def _parse_set_item_assignment( - self, kind: t.Optional[str] = None - ) -> t.Optional[exp.Expression]: - index = self._index - - if kind in ("GLOBAL", "SESSION") and self._match_text_seq("TRANSACTION"): - return self._parse_set_transaction(global_=kind == "GLOBAL") - - left = self._parse_primary() or self._parse_column() - assignment_delimiter = self._match_texts(("=", "TO")) - - if not left or (self.SET_REQUIRES_ASSIGNMENT_DELIMITER and not assignment_delimiter): - self._retreat(index) - return None - - right = self._parse_statement() or self._parse_id_var() - if isinstance(right, (exp.Column, exp.Identifier)): - right = exp.var(right.name) - - this = self.expression(exp.EQ, this=left, expression=right) - return self.expression(exp.SetItem, this=this, kind=kind) - - def _parse_set_transaction(self, global_: bool = False) -> exp.Expression: - self._match_text_seq("TRANSACTION") - characteristics = self._parse_csv( - lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS) - ) - return self.expression( - exp.SetItem, - expressions=characteristics, - kind="TRANSACTION", - **{"global": global_}, # type: ignore - ) - - def _parse_set_item(self) -> t.Optional[exp.Expression]: - parser = self._find_parser(self.SET_PARSERS, self.SET_TRIE) - return parser(self) if parser else self._parse_set_item_assignment(kind=None) - - def _parse_set(self, unset: bool = False, tag: bool = False) -> exp.Set | exp.Command: - index = self._index - set_ = self.expression( - exp.Set, expressions=self._parse_csv(self._parse_set_item), unset=unset, tag=tag - ) - - if self._curr: - self._retreat(index) - return self._parse_as_command(self._prev) - - return set_ - - def _parse_var_from_options( - self, options: OPTIONS_TYPE, raise_unmatched: bool = True - ) -> t.Optional[exp.Var]: - start = self._curr - if not start: - return None - - option = start.text.upper() - continuations = options.get(option) - - index = self._index - self._advance() - for keywords in continuations or []: - if isinstance(keywords, str): - keywords = (keywords,) - - if self._match_text_seq(*keywords): - option = f"{option} {' '.join(keywords)}" - break - else: - if continuations or continuations is None: - if raise_unmatched: - self.raise_error(f"Unknown option {option}") - - self._retreat(index) - return None - - return exp.var(option) - - def _parse_as_command(self, start: Token) -> exp.Command: - while self._curr: - self._advance() - text = self._find_sql(start, self._prev) - size = len(start.text) - self._warn_unsupported() - return exp.Command(this=text[:size], expression=text[size:]) - - def _parse_dict_property(self, this: str) -> exp.DictProperty: - settings = [] - - self._match_l_paren() - kind = self._parse_id_var() - - if self._match(TokenType.L_PAREN): - while True: - key = self._parse_id_var() - value = self._parse_primary() - if not key and value is None: - break - settings.append(self.expression(exp.DictSubProperty, this=key, value=value)) - self._match(TokenType.R_PAREN) - - self._match_r_paren() - - return self.expression( - exp.DictProperty, - this=this, - kind=kind.this if kind else None, - settings=settings, - ) - - def _parse_dict_range(self, this: str) -> exp.DictRange: - self._match_l_paren() - has_min = 
self._match_text_seq("MIN") - if has_min: - min = self._parse_var() or self._parse_primary() - self._match_text_seq("MAX") - max = self._parse_var() or self._parse_primary() - else: - max = self._parse_var() or self._parse_primary() - min = exp.Literal.number(0) - self._match_r_paren() - return self.expression(exp.DictRange, this=this, min=min, max=max) - - def _parse_comprehension( - self, this: t.Optional[exp.Expression] - ) -> t.Optional[exp.Comprehension]: - index = self._index - expression = self._parse_column() - if not self._match(TokenType.IN): - self._retreat(index - 1) - return None - iterator = self._parse_column() - condition = self._parse_assignment() if self._match_text_seq("IF") else None - return self.expression( - exp.Comprehension, - this=this, - expression=expression, - iterator=iterator, - condition=condition, - ) - - def _parse_heredoc(self) -> t.Optional[exp.Heredoc]: - if self._match(TokenType.HEREDOC_STRING): - return self.expression(exp.Heredoc, this=self._prev.text) - - if not self._match_text_seq("$"): - return None - - tags = ["$"] - tag_text = None - - if self._is_connected(): - self._advance() - tags.append(self._prev.text.upper()) - else: - self.raise_error("No closing $ found") - - if tags[-1] != "$": - if self._is_connected() and self._match_text_seq("$"): - tag_text = tags[-1] - tags.append("$") - else: - self.raise_error("No closing $ found") - - heredoc_start = self._curr - - while self._curr: - if self._match_text_seq(*tags, advance=False): - this = self._find_sql(heredoc_start, self._prev) - self._advance(len(tags)) - return self.expression(exp.Heredoc, this=this, tag=tag_text) - - self._advance() - - self.raise_error(f"No closing {''.join(tags)} found") - return None - - def _find_parser( - self, parsers: t.Dict[str, t.Callable], trie: t.Dict - ) -> t.Optional[t.Callable]: - if not self._curr: - return None - - index = self._index - this = [] - while True: - # The current token might be multiple words - curr = self._curr.text.upper() - key = curr.split(" ") - this.append(curr) - - self._advance() - result, trie = in_trie(trie, key) - if result == TrieResult.FAILED: - break - - if result == TrieResult.EXISTS: - subparser = parsers[" ".join(this)] - return subparser - - self._retreat(index) - return None - - def _match(self, token_type, advance=True, expression=None): - if not self._curr: - return None - - if self._curr.token_type == token_type: - if advance: - self._advance() - self._add_comments(expression) - return True - - return None - - def _match_set(self, types, advance=True): - if not self._curr: - return None - - if self._curr.token_type in types: - if advance: - self._advance() - return True - - return None - - def _match_pair(self, token_type_a, token_type_b, advance=True): - if not self._curr or not self._next: - return None - - if self._curr.token_type == token_type_a and self._next.token_type == token_type_b: - if advance: - self._advance(2) - return True - - return None - - def _match_l_paren(self, expression: t.Optional[exp.Expression] = None) -> None: - if not self._match(TokenType.L_PAREN, expression=expression): - self.raise_error("Expecting (") - - def _match_r_paren(self, expression: t.Optional[exp.Expression] = None) -> None: - if not self._match(TokenType.R_PAREN, expression=expression): - self.raise_error("Expecting )") - - def _match_texts(self, texts, advance=True): - if ( - self._curr - and self._curr.token_type != TokenType.STRING - and self._curr.text.upper() in texts - ): - if advance: - self._advance() - return True - 
return None - - def _match_text_seq(self, *texts, advance=True): - index = self._index - for text in texts: - if ( - self._curr - and self._curr.token_type != TokenType.STRING - and self._curr.text.upper() == text - ): - self._advance() - else: - self._retreat(index) - return None - - if not advance: - self._retreat(index) - - return True - - def _replace_lambda( - self, node: t.Optional[exp.Expression], expressions: t.List[exp.Expression] - ) -> t.Optional[exp.Expression]: - if not node: - return node - - lambda_types = {e.name: e.args.get("to") or False for e in expressions} - - for column in node.find_all(exp.Column): - typ = lambda_types.get(column.parts[0].name) - if typ is not None: - dot_or_id = column.to_dot() if column.table else column.this - - if typ: - dot_or_id = self.expression( - exp.Cast, - this=dot_or_id, - to=typ, - ) - - parent = column.parent - - while isinstance(parent, exp.Dot): - if not isinstance(parent.parent, exp.Dot): - parent.replace(dot_or_id) - break - parent = parent.parent - else: - if column is node: - node = dot_or_id - else: - column.replace(dot_or_id) - return node - - def _parse_truncate_table(self) -> t.Optional[exp.TruncateTable] | exp.Expression: - start = self._prev - - # Not to be confused with TRUNCATE(number, decimals) function call - if self._match(TokenType.L_PAREN): - self._retreat(self._index - 2) - return self._parse_function() - - # Clickhouse supports TRUNCATE DATABASE as well - is_database = self._match(TokenType.DATABASE) - - self._match(TokenType.TABLE) - - exists = self._parse_exists(not_=False) - - expressions = self._parse_csv( - lambda: self._parse_table(schema=True, is_db_reference=is_database) - ) - - cluster = self._parse_on_property() if self._match(TokenType.ON) else None - - if self._match_text_seq("RESTART", "IDENTITY"): - identity = "RESTART" - elif self._match_text_seq("CONTINUE", "IDENTITY"): - identity = "CONTINUE" - else: - identity = None - - if self._match_text_seq("CASCADE") or self._match_text_seq("RESTRICT"): - option = self._prev.text - else: - option = None - - partition = self._parse_partition() - - # Fallback case - if self._curr: - return self._parse_as_command(start) - - return self.expression( - exp.TruncateTable, - expressions=expressions, - is_database=is_database, - exists=exists, - cluster=cluster, - identity=identity, - option=option, - partition=partition, - ) - - def _parse_with_operator(self) -> t.Optional[exp.Expression]: - this = self._parse_ordered(self._parse_opclass) - - if not self._match(TokenType.WITH): - return this - - op = self._parse_var(any_token=True) - - return self.expression(exp.WithOperator, this=this, op=op) - - def _parse_wrapped_options(self) -> t.List[t.Optional[exp.Expression]]: - self._match(TokenType.EQ) - self._match(TokenType.L_PAREN) - - opts: t.List[t.Optional[exp.Expression]] = [] - option: exp.Expression | None - while self._curr and not self._match(TokenType.R_PAREN): - if self._match_text_seq("FORMAT_NAME", "="): - # The FORMAT_NAME can be set to an identifier for Snowflake and T-SQL - option = self._parse_format_name() - else: - option = self._parse_property() - - if option is None: - self.raise_error("Unable to parse option") - break - - opts.append(option) - - return opts - - def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]: - sep = TokenType.COMMA if self.dialect.COPY_PARAMS_ARE_CSV else None - - options = [] - while self._curr and not self._match(TokenType.R_PAREN, advance=False): - option = self._parse_var(any_token=True) - prev = 
self._prev.text.upper() - - # Different dialects might separate options and values by white space, "=" and "AS" - self._match(TokenType.EQ) - self._match(TokenType.ALIAS) - - param = self.expression(exp.CopyParameter, this=option) - - if prev in self.COPY_INTO_VARLEN_OPTIONS and self._match( - TokenType.L_PAREN, advance=False - ): - # Snowflake FILE_FORMAT case, Databricks COPY & FORMAT options - param.set("expressions", self._parse_wrapped_options()) - elif prev == "FILE_FORMAT": - # T-SQL's external file format case - param.set("expression", self._parse_field()) - else: - param.set("expression", self._parse_unquoted_field()) - - options.append(param) - self._match(sep) - - return options - - def _parse_credentials(self) -> t.Optional[exp.Credentials]: - expr = self.expression(exp.Credentials) - - if self._match_text_seq("STORAGE_INTEGRATION", "="): - expr.set("storage", self._parse_field()) - if self._match_text_seq("CREDENTIALS"): - # Snowflake case: CREDENTIALS = (...), Redshift case: CREDENTIALS - creds = ( - self._parse_wrapped_options() if self._match(TokenType.EQ) else self._parse_field() - ) - expr.set("credentials", creds) - if self._match_text_seq("ENCRYPTION"): - expr.set("encryption", self._parse_wrapped_options()) - if self._match_text_seq("IAM_ROLE"): - expr.set("iam_role", self._parse_field()) - if self._match_text_seq("REGION"): - expr.set("region", self._parse_field()) - - return expr - - def _parse_file_location(self) -> t.Optional[exp.Expression]: - return self._parse_field() - - def _parse_copy(self) -> exp.Copy | exp.Command: - start = self._prev - - self._match(TokenType.INTO) - - this = ( - self._parse_select(nested=True, parse_subquery_alias=False) - if self._match(TokenType.L_PAREN, advance=False) - else self._parse_table(schema=True) - ) - - kind = self._match(TokenType.FROM) or not self._match_text_seq("TO") - - files = self._parse_csv(self._parse_file_location) - credentials = self._parse_credentials() - - self._match_text_seq("WITH") - - params = self._parse_wrapped(self._parse_copy_parameters, optional=True) - - # Fallback case - if self._curr: - return self._parse_as_command(start) - - return self.expression( - exp.Copy, - this=this, - kind=kind, - credentials=credentials, - files=files, - params=params, - ) - - def _parse_normalize(self) -> exp.Normalize: - return self.expression( - exp.Normalize, - this=self._parse_bitwise(), - form=self._match(TokenType.COMMA) and self._parse_var(), - ) - - def _parse_ceil_floor(self, expr_type: t.Type[TCeilFloor]) -> TCeilFloor: - args = self._parse_csv(lambda: self._parse_lambda()) - - this = seq_get(args, 0) - decimals = seq_get(args, 1) - - return expr_type( - this=this, decimals=decimals, to=self._match_text_seq("TO") and self._parse_var() - ) - - def _parse_star_ops(self) -> t.Optional[exp.Expression]: - star_token = self._prev - - if self._match_text_seq("COLUMNS", "(", advance=False): - this = self._parse_function() - if isinstance(this, exp.Columns): - this.set("unpack", True) - return this - - return self.expression( - exp.Star, - **{ # type: ignore - "except": self._parse_star_op("EXCEPT", "EXCLUDE"), - "replace": self._parse_star_op("REPLACE"), - "rename": self._parse_star_op("RENAME"), - }, - ).update_positions(star_token) - - def _parse_grant_privilege(self) -> t.Optional[exp.GrantPrivilege]: - privilege_parts = [] - - # Keep consuming consecutive keywords until comma (end of this privilege) or ON - # (end of privilege list) or L_PAREN (start of column list) are met - while self._curr and not 
self._match_set(self.PRIVILEGE_FOLLOW_TOKENS, advance=False): - privilege_parts.append(self._curr.text.upper()) - self._advance() - - this = exp.var(" ".join(privilege_parts)) - expressions = ( - self._parse_wrapped_csv(self._parse_column) - if self._match(TokenType.L_PAREN, advance=False) - else None - ) - - return self.expression(exp.GrantPrivilege, this=this, expressions=expressions) - - def _parse_grant_principal(self) -> t.Optional[exp.GrantPrincipal]: - kind = self._match_texts(("ROLE", "GROUP")) and self._prev.text.upper() - principal = self._parse_id_var() - - if not principal: - return None - - return self.expression(exp.GrantPrincipal, this=principal, kind=kind) - - def _parse_grant(self) -> exp.Grant | exp.Command: - start = self._prev - - privileges = self._parse_csv(self._parse_grant_privilege) - - self._match(TokenType.ON) - kind = self._match_set(self.CREATABLES) and self._prev.text.upper() - - # Attempt to parse the securable e.g. MySQL allows names - # such as "foo.*", "*.*" which are not easily parseable yet - securable = self._try_parse(self._parse_table_parts) - - if not securable or not self._match_text_seq("TO"): - return self._parse_as_command(start) - - principals = self._parse_csv(self._parse_grant_principal) - - grant_option = self._match_text_seq("WITH", "GRANT", "OPTION") - - if self._curr: - return self._parse_as_command(start) - - return self.expression( - exp.Grant, - privileges=privileges, - kind=kind, - securable=securable, - principals=principals, - grant_option=grant_option, - ) - - def _parse_overlay(self) -> exp.Overlay: - return self.expression( - exp.Overlay, - **{ # type: ignore - "this": self._parse_bitwise(), - "expression": self._match_text_seq("PLACING") and self._parse_bitwise(), - "from": self._match_text_seq("FROM") and self._parse_bitwise(), - "for": self._match_text_seq("FOR") and self._parse_bitwise(), - }, - ) - - def _parse_format_name(self) -> exp.Property: - # Note: Although not specified in the docs, Snowflake does accept a string/identifier - # for FILE_FORMAT = - return self.expression( - exp.Property, - this=exp.var("FORMAT_NAME"), - value=self._parse_string() or self._parse_table_parts(), - ) - - def _parse_max_min_by(self, expr_type: t.Type[exp.AggFunc]) -> exp.AggFunc: - args: t.List[exp.Expression] = [] - - if self._match(TokenType.DISTINCT): - args.append(self.expression(exp.Distinct, expressions=[self._parse_assignment()])) - self._match(TokenType.COMMA) - - args.extend(self._parse_csv(self._parse_assignment)) - - return self.expression( - expr_type, this=seq_get(args, 0), expression=seq_get(args, 1), count=seq_get(args, 2) - ) - - def _identifier_expression( - self, token: t.Optional[Token] = None, **kwargs: t.Any - ) -> exp.Identifier: - token = token or self._prev - expression = self.expression(exp.Identifier, this=token.text, **kwargs) - expression.update_positions(token) - return expression - - def _build_pipe_cte( - self, - query: exp.Query, - expressions: t.List[exp.Expression], - alias_cte: t.Optional[exp.TableAlias] = None, - ) -> exp.Select: - new_cte: t.Optional[t.Union[str, exp.TableAlias]] - if alias_cte: - new_cte = alias_cte - else: - self._pipe_cte_counter += 1 - new_cte = f"__tmp{self._pipe_cte_counter}" - - with_ = query.args.get("with") - ctes = with_.pop() if with_ else None - - new_select = exp.select(*expressions, copy=False).from_(new_cte, copy=False) - if ctes: - new_select.set("with", ctes) - - return new_select.with_(new_cte, as_=query, copy=False) - - def _parse_pipe_syntax_select(self, query: 
exp.Select) -> exp.Select: - select = self._parse_select() - if not select: - return query - - return self._build_pipe_cte(query.select(*select.expressions, append=False), [exp.Star()]) - - def _parse_pipe_syntax_limit(self, query: exp.Select) -> exp.Select: - limit = self._parse_limit() - offset = self._parse_offset() - if limit: - curr_limit = query.args.get("limit", limit) - if curr_limit.expression.to_py() >= limit.expression.to_py(): - query.limit(limit, copy=False) - if offset: - curr_offset = query.args.get("offset") - curr_offset = curr_offset.expression.to_py() if curr_offset else 0 - query.offset(exp.Literal.number(curr_offset + offset.expression.to_py()), copy=False) - - return query - - def _parse_pipe_syntax_aggregate_fields(self) -> t.Optional[exp.Expression]: - this = self._parse_assignment() - if self._match_text_seq("GROUP", "AND", advance=False): - return this - - this = self._parse_alias(this) - - if self._match_set((TokenType.ASC, TokenType.DESC), advance=False): - return self._parse_ordered(lambda: this) - - return this - - def _parse_pipe_syntax_aggregate_group_order_by( - self, query: exp.Select, group_by_exists: bool = True - ) -> exp.Select: - expr = self._parse_csv(self._parse_pipe_syntax_aggregate_fields) - aggregates_or_groups, orders = [], [] - for element in expr: - if isinstance(element, exp.Ordered): - this = element.this - if isinstance(this, exp.Alias): - element.set("this", this.args["alias"]) - orders.append(element) - else: - this = element - aggregates_or_groups.append(this) - - if group_by_exists: - query.select(*aggregates_or_groups, copy=False).group_by( - *[projection.args.get("alias", projection) for projection in aggregates_or_groups], - copy=False, - ) - else: - query.select(*aggregates_or_groups, append=False, copy=False) - - if orders: - return query.order_by(*orders, append=False, copy=False) - - return query - - def _parse_pipe_syntax_aggregate(self, query: exp.Select) -> exp.Select: - self._match_text_seq("AGGREGATE") - query = self._parse_pipe_syntax_aggregate_group_order_by(query, group_by_exists=False) - - if self._match(TokenType.GROUP_BY) or ( - self._match_text_seq("GROUP", "AND") and self._match(TokenType.ORDER_BY) - ): - query = self._parse_pipe_syntax_aggregate_group_order_by(query) - - return self._build_pipe_cte(query, [exp.Star()]) - - def _parse_pipe_syntax_set_operator(self, query: exp.Query) -> t.Optional[exp.Select]: - first_setop = self.parse_set_operation(this=query) - if not first_setop: - return None - - def _parse_and_unwrap_query() -> t.Optional[exp.Select]: - expr = self._parse_paren() - return expr.assert_is(exp.Subquery).unnest() if expr else None - - first_setop.this.pop() - - setops = [ - first_setop.expression.pop().assert_is(exp.Subquery).unnest(), - *self._parse_csv(_parse_and_unwrap_query), - ] - - query = self._build_pipe_cte(query, [exp.Star()]) - with_ = query.args.get("with") - ctes = with_.pop() if with_ else None - - if isinstance(first_setop, exp.Union): - query = query.union(*setops, copy=False, **first_setop.args) - elif isinstance(first_setop, exp.Except): - query = query.except_(*setops, copy=False, **first_setop.args) - else: - query = query.intersect(*setops, copy=False, **first_setop.args) - - query.set("with", ctes) - - return self._build_pipe_cte(query, [exp.Star()]) - - def _parse_pipe_syntax_join(self, query: exp.Select) -> t.Optional[exp.Select]: - join = self._parse_join() - if not join: - return None - - return query.join(join, copy=False) - - def _parse_pipe_syntax_pivot(self, query: 
exp.Select) -> exp.Select: - pivots = self._parse_pivots() - if not pivots: - return query - - from_ = query.args.get("from") - if from_: - from_.this.set("pivots", pivots) - - return self._build_pipe_cte(query, [exp.Star()]) - - def _parse_pipe_syntax_extend(self, query: exp.Select) -> exp.Select: - self._match_text_seq("EXTEND") - query.select(*[exp.Star(), *self._parse_expressions()], append=False, copy=False) - return self._build_pipe_cte(query, [exp.Star()]) - - def _parse_pipe_syntax_query(self, query: exp.Select) -> t.Optional[exp.Select]: - while self._match(TokenType.PIPE_GT): - start = self._curr - parser = self.PIPE_SYNTAX_TRANSFORM_PARSERS.get(self._curr.text.upper()) - if not parser: - # The set operators (UNION, etc) and the JOIN operator have a few common starting - # keywords, making it tricky to disambiguate them without lookahead. The approach - # here is to try and parse a set operation and if that fails, then try to parse a - # join operator. If that fails as well, then the operator is not supported. - parsed_query = self._parse_pipe_syntax_set_operator(query) - parsed_query = parsed_query or self._parse_pipe_syntax_join(query) - if not parsed_query: - self._retreat(start) - self.raise_error(f"Unsupported pipe syntax operator: '{start.text.upper()}'.") - break - query = parsed_query - else: - query = parser(self, query) - - return query diff --git a/altimate_packages/sqlglot/planner.py b/altimate_packages/sqlglot/planner.py deleted file mode 100644 index 687bffb9f..000000000 --- a/altimate_packages/sqlglot/planner.py +++ /dev/null @@ -1,463 +0,0 @@ -from __future__ import annotations - -import math -import typing as t - -from sqlglot import alias, exp -from sqlglot.helper import name_sequence -from sqlglot.optimizer.eliminate_joins import join_condition - - -class Plan: - def __init__(self, expression: exp.Expression) -> None: - self.expression = expression.copy() - self.root = Step.from_expression(self.expression) - self._dag: t.Dict[Step, t.Set[Step]] = {} - - @property - def dag(self) -> t.Dict[Step, t.Set[Step]]: - if not self._dag: - dag: t.Dict[Step, t.Set[Step]] = {} - nodes = {self.root} - - while nodes: - node = nodes.pop() - dag[node] = set() - - for dep in node.dependencies: - dag[node].add(dep) - nodes.add(dep) - - self._dag = dag - - return self._dag - - @property - def leaves(self) -> t.Iterator[Step]: - return (node for node, deps in self.dag.items() if not deps) - - def __repr__(self) -> str: - return f"Plan\n----\n{repr(self.root)}" - - -class Step: - @classmethod - def from_expression( - cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None - ) -> Step: - """ - Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine. - Note: the expression's tables and subqueries must be aliased for this method to work. For - example, given the following expression: - - SELECT - x.a, - SUM(x.b) - FROM x AS x - JOIN y AS y - ON x.a = y.a - GROUP BY x.a - - the following DAG is produced (the expression IDs might differ per execution): - - - Aggregate: x (4347984624) - Context: - Aggregations: - - SUM(x.b) - Group: - - x.a - Projections: - - x.a - - "x"."" - Dependencies: - - Join: x (4347985296) - Context: - y: - On: x.a = y.a - Projections: - Dependencies: - - Scan: x (4347983136) - Context: - Source: x AS x - Projections: - - Scan: y (4343416624) - Context: - Source: y AS y - Projections: - - Args: - expression: the expression to build the DAG from. 
- ctes: a dictionary that maps CTEs to their corresponding Step DAG by name. - - Returns: - A Step DAG corresponding to `expression`. - """ - ctes = ctes or {} - expression = expression.unnest() - with_ = expression.args.get("with") - - # CTEs break the mold of scope and introduce themselves to all in the context. - if with_: - ctes = ctes.copy() - for cte in with_.expressions: - step = Step.from_expression(cte.this, ctes) - step.name = cte.alias - ctes[step.name] = step # type: ignore - - from_ = expression.args.get("from") - - if isinstance(expression, exp.Select) and from_: - step = Scan.from_expression(from_.this, ctes) - elif isinstance(expression, exp.SetOperation): - step = SetOperation.from_expression(expression, ctes) - else: - step = Scan() - - joins = expression.args.get("joins") - - if joins: - join = Join.from_joins(joins, ctes) - join.name = step.name - join.source_name = step.name - join.add_dependency(step) - step = join - - projections = [] # final selects in this chain of steps representing a select - operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) - aggregations = {} - next_operand_name = name_sequence("_a_") - - def extract_agg_operands(expression): - agg_funcs = tuple(expression.find_all(exp.AggFunc)) - if agg_funcs: - aggregations[expression] = None - - for agg in agg_funcs: - for operand in agg.unnest_operands(): - if isinstance(operand, exp.Column): - continue - if operand not in operands: - operands[operand] = next_operand_name() - - operand.replace(exp.column(operands[operand], quoted=True)) - - return bool(agg_funcs) - - def set_ops_and_aggs(step): - step.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items()) - step.aggregations = list(aggregations) - - for e in expression.expressions: - if e.find(exp.AggFunc): - projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) - extract_agg_operands(e) - else: - projections.append(e) - - where = expression.args.get("where") - - if where: - step.condition = where.this - - group = expression.args.get("group") - - if group or aggregations: - aggregate = Aggregate() - aggregate.source = step.name - aggregate.name = step.name - - having = expression.args.get("having") - - if having: - if extract_agg_operands(exp.alias_(having.this, "_h", quoted=True)): - aggregate.condition = exp.column("_h", step.name, quoted=True) - else: - aggregate.condition = having.this - - set_ops_and_aggs(aggregate) - - # give aggregates names and replace projections with references to them - aggregate.group = { - f"_g{i}": e for i, e in enumerate(group.expressions if group else []) - } - - intermediate: t.Dict[str | exp.Expression, str] = {} - for k, v in aggregate.group.items(): - intermediate[v] = k - if isinstance(v, exp.Column): - intermediate[v.name] = k - - for projection in projections: - for node in projection.walk(): - name = intermediate.get(node) - if name: - node.replace(exp.column(name, step.name)) - - if aggregate.condition: - for node in aggregate.condition.walk(): - name = intermediate.get(node) or intermediate.get(node.name) - if name: - node.replace(exp.column(name, step.name)) - - aggregate.add_dependency(step) - step = aggregate - else: - aggregate = None - - order = expression.args.get("order") - - if order: - if aggregate and isinstance(step, Aggregate): - for i, ordered in enumerate(order.expressions): - if extract_agg_operands(exp.alias_(ordered.this, f"_o_{i}", quoted=True)): - ordered.this.replace(exp.column(f"_o_{i}", step.name, quoted=True)) - - 
set_ops_and_aggs(aggregate) - - sort = Sort() - sort.name = step.name - sort.key = order.expressions - sort.add_dependency(step) - step = sort - - step.projections = projections - - if isinstance(expression, exp.Select) and expression.args.get("distinct"): - distinct = Aggregate() - distinct.source = step.name - distinct.name = step.name - distinct.group = { - e.alias_or_name: exp.column(col=e.alias_or_name, table=step.name) - for e in projections or expression.expressions - } - distinct.add_dependency(step) - step = distinct - - limit = expression.args.get("limit") - - if limit: - step.limit = int(limit.text("expression")) - - return step - - def __init__(self) -> None: - self.name: t.Optional[str] = None - self.dependencies: t.Set[Step] = set() - self.dependents: t.Set[Step] = set() - self.projections: t.Sequence[exp.Expression] = [] - self.limit: float = math.inf - self.condition: t.Optional[exp.Expression] = None - - def add_dependency(self, dependency: Step) -> None: - self.dependencies.add(dependency) - dependency.dependents.add(self) - - def __repr__(self) -> str: - return self.to_s() - - def to_s(self, level: int = 0) -> str: - indent = " " * level - nested = f"{indent} " - - context = self._to_s(f"{nested} ") - - if context: - context = [f"{nested}Context:"] + context - - lines = [ - f"{indent}- {self.id}", - *context, - f"{nested}Projections:", - ] - - for expression in self.projections: - lines.append(f"{nested} - {expression.sql()}") - - if self.condition: - lines.append(f"{nested}Condition: {self.condition.sql()}") - - if self.limit is not math.inf: - lines.append(f"{nested}Limit: {self.limit}") - - if self.dependencies: - lines.append(f"{nested}Dependencies:") - for dependency in self.dependencies: - lines.append(" " + dependency.to_s(level + 1)) - - return "\n".join(lines) - - @property - def type_name(self) -> str: - return self.__class__.__name__ - - @property - def id(self) -> str: - name = self.name - name = f" {name}" if name else "" - return f"{self.type_name}:{name} ({id(self)})" - - def _to_s(self, _indent: str) -> t.List[str]: - return [] - - -class Scan(Step): - @classmethod - def from_expression( - cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None - ) -> Step: - table = expression - alias_ = expression.alias_or_name - - if isinstance(expression, exp.Subquery): - table = expression.this - step = Step.from_expression(table, ctes) - step.name = alias_ - return step - - step = Scan() - step.name = alias_ - step.source = expression - if ctes and table.name in ctes: - step.add_dependency(ctes[table.name]) - - return step - - def __init__(self) -> None: - super().__init__() - self.source: t.Optional[exp.Expression] = None - - def _to_s(self, indent: str) -> t.List[str]: - return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore - - -class Join(Step): - @classmethod - def from_joins( - cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None - ) -> Join: - step = Join() - - for join in joins: - source_key, join_key, condition = join_condition(join) - step.joins[join.alias_or_name] = { - "side": join.side, # type: ignore - "join_key": join_key, - "source_key": source_key, - "condition": condition, - } - - step.add_dependency(Scan.from_expression(join.this, ctes)) - - return step - - def __init__(self) -> None: - super().__init__() - self.source_name: t.Optional[str] = None - self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {} - - def _to_s(self, indent: str) -> 
t.List[str]: - lines = [f"{indent}Source: {self.source_name or self.name}"] - for name, join in self.joins.items(): - lines.append(f"{indent}{name}: {join['side'] or 'INNER'}") - join_key = ", ".join(str(key) for key in t.cast(list, join.get("join_key") or [])) - if join_key: - lines.append(f"{indent}Key: {join_key}") - if join.get("condition"): - lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore - return lines - - -class Aggregate(Step): - def __init__(self) -> None: - super().__init__() - self.aggregations: t.List[exp.Expression] = [] - self.operands: t.Tuple[exp.Expression, ...] = () - self.group: t.Dict[str, exp.Expression] = {} - self.source: t.Optional[str] = None - - def _to_s(self, indent: str) -> t.List[str]: - lines = [f"{indent}Aggregations:"] - - for expression in self.aggregations: - lines.append(f"{indent} - {expression.sql()}") - - if self.group: - lines.append(f"{indent}Group:") - for expression in self.group.values(): - lines.append(f"{indent} - {expression.sql()}") - if self.condition: - lines.append(f"{indent}Having:") - lines.append(f"{indent} - {self.condition.sql()}") - if self.operands: - lines.append(f"{indent}Operands:") - for expression in self.operands: - lines.append(f"{indent} - {expression.sql()}") - - return lines - - -class Sort(Step): - def __init__(self) -> None: - super().__init__() - self.key = None - - def _to_s(self, indent: str) -> t.List[str]: - lines = [f"{indent}Key:"] - - for expression in self.key: # type: ignore - lines.append(f"{indent} - {expression.sql()}") - - return lines - - -class SetOperation(Step): - def __init__( - self, - op: t.Type[exp.Expression], - left: str | None, - right: str | None, - distinct: bool = False, - ) -> None: - super().__init__() - self.op = op - self.left = left - self.right = right - self.distinct = distinct - - @classmethod - def from_expression( - cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None - ) -> SetOperation: - assert isinstance(expression, exp.SetOperation) - - left = Step.from_expression(expression.left, ctes) - # SELECT 1 UNION SELECT 2 <-- these subqueries don't have names - left.name = left.name or "left" - right = Step.from_expression(expression.right, ctes) - right.name = right.name or "right" - step = cls( - op=expression.__class__, - left=left.name, - right=right.name, - distinct=bool(expression.args.get("distinct")), - ) - - step.add_dependency(left) - step.add_dependency(right) - - limit = expression.args.get("limit") - - if limit: - step.limit = int(limit.text("expression")) - - return step - - def _to_s(self, indent: str) -> t.List[str]: - lines = [] - if self.distinct: - lines.append(f"{indent}Distinct: {self.distinct}") - return lines - - @property - def type_name(self) -> str: - return self.op.__name__ diff --git a/altimate_packages/sqlglot/py.typed b/altimate_packages/sqlglot/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/altimate_packages/sqlglot/schema.py b/altimate_packages/sqlglot/schema.py deleted file mode 100644 index c8362e081..000000000 --- a/altimate_packages/sqlglot/schema.py +++ /dev/null @@ -1,588 +0,0 @@ -from __future__ import annotations - -import abc -import typing as t - -from sqlglot import expressions as exp -from sqlglot.dialects.dialect import Dialect -from sqlglot.errors import SchemaError -from sqlglot.helper import dict_depth, first -from sqlglot.trie import TrieResult, in_trie, new_trie - -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - - ColumnMapping = 
t.Union[t.Dict, str, t.List] - - -class Schema(abc.ABC): - """Abstract base class for database schemas""" - - dialect: DialectType - - @abc.abstractmethod - def add_table( - self, - table: exp.Table | str, - column_mapping: t.Optional[ColumnMapping] = None, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - match_depth: bool = True, - ) -> None: - """ - Register or update a table. Some implementing classes may require column information to also be provided. - The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. - - Args: - table: the `Table` expression instance or string representing the table. - column_mapping: a column mapping that describes the structure of the table. - dialect: the SQL dialect that will be used to parse `table` if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest. - match_depth: whether to enforce that the table must match the schema's depth or not. - """ - - @abc.abstractmethod - def column_names( - self, - table: exp.Table | str, - only_visible: bool = False, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - ) -> t.Sequence[str]: - """ - Get the column names for a table. - - Args: - table: the `Table` expression instance. - only_visible: whether to include invisible columns. - dialect: the SQL dialect that will be used to parse `table` if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest. - - Returns: - The sequence of column names. - """ - - @abc.abstractmethod - def get_column_type( - self, - table: exp.Table | str, - column: exp.Column | str, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - ) -> exp.DataType: - """ - Get the `sqlglot.exp.DataType` type of a column in the schema. - - Args: - table: the source table. - column: the target column. - dialect: the SQL dialect that will be used to parse `table` if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest. - - Returns: - The resulting column type. - """ - - def has_column( - self, - table: exp.Table | str, - column: exp.Column | str, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - ) -> bool: - """ - Returns whether `column` appears in `table`'s schema. - - Args: - table: the source table. - column: the target column. - dialect: the SQL dialect that will be used to parse `table` if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest. - - Returns: - True if the column appears in the schema, False otherwise. - """ - name = column if isinstance(column, str) else column.name - return name in self.column_names(table, dialect=dialect, normalize=normalize) - - @property - @abc.abstractmethod - def supported_table_args(self) -> t.Tuple[str, ...]: - """ - Table arguments this schema support, e.g. `("this", "db", "catalog")` - """ - - @property - def empty(self) -> bool: - """Returns whether the schema is empty.""" - return True - - -class AbstractMappingSchema: - def __init__( - self, - mapping: t.Optional[t.Dict] = None, - ) -> None: - self.mapping = mapping or {} - self.mapping_trie = new_trie( - tuple(reversed(t)) for t in flatten_schema(self.mapping, depth=self.depth()) - ) - self._supported_table_args: t.Tuple[str, ...] 
= tuple() - - @property - def empty(self) -> bool: - return not self.mapping - - def depth(self) -> int: - return dict_depth(self.mapping) - - @property - def supported_table_args(self) -> t.Tuple[str, ...]: - if not self._supported_table_args and self.mapping: - depth = self.depth() - - if not depth: # None - self._supported_table_args = tuple() - elif 1 <= depth <= 3: - self._supported_table_args = exp.TABLE_PARTS[:depth] - else: - raise SchemaError(f"Invalid mapping shape. Depth: {depth}") - - return self._supported_table_args - - def table_parts(self, table: exp.Table) -> t.List[str]: - return [part.name for part in reversed(table.parts)] - - def find( - self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False - ) -> t.Optional[t.Any]: - """ - Returns the schema of a given table. - - Args: - table: the target table. - raise_on_missing: whether to raise in case the schema is not found. - ensure_data_types: whether to convert `str` types to their `DataType` equivalents. - - Returns: - The schema of the target table. - """ - parts = self.table_parts(table)[0 : len(self.supported_table_args)] - value, trie = in_trie(self.mapping_trie, parts) - - if value == TrieResult.FAILED: - return None - - if value == TrieResult.PREFIX: - possibilities = flatten_schema(trie) - - if len(possibilities) == 1: - parts.extend(possibilities[0]) - else: - message = ", ".join(".".join(parts) for parts in possibilities) - if raise_on_missing: - raise SchemaError(f"Ambiguous mapping for {table}: {message}.") - return None - - return self.nested_get(parts, raise_on_missing=raise_on_missing) - - def nested_get( - self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True - ) -> t.Optional[t.Any]: - return nested_get( - d or self.mapping, - *zip(self.supported_table_args, reversed(parts)), - raise_on_missing=raise_on_missing, - ) - - -class MappingSchema(AbstractMappingSchema, Schema): - """ - Schema based on a nested mapping. - - Args: - schema: Mapping in one of the following forms: - 1. {table: {col: type}} - 2. {db: {table: {col: type}}} - 3. {catalog: {db: {table: {col: type}}}} - 4. None - Tables will be added later - visible: Optional mapping of which columns in the schema are visible. If not provided, all columns - are assumed to be visible. The nesting should mirror that of the schema: - 1. {table: set(*cols)}} - 2. {db: {table: set(*cols)}}} - 3. {catalog: {db: {table: set(*cols)}}}} - dialect: The dialect to be used for custom type mappings & parsing string arguments. - normalize: Whether to normalize identifier names according to the given dialect or not. 
- """ - - def __init__( - self, - schema: t.Optional[t.Dict] = None, - visible: t.Optional[t.Dict] = None, - dialect: DialectType = None, - normalize: bool = True, - ) -> None: - self.dialect = dialect - self.visible = {} if visible is None else visible - self.normalize = normalize - self._type_mapping_cache: t.Dict[str, exp.DataType] = {} - self._depth = 0 - schema = {} if schema is None else schema - - super().__init__(self._normalize(schema) if self.normalize else schema) - - @classmethod - def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: - return MappingSchema( - schema=mapping_schema.mapping, - visible=mapping_schema.visible, - dialect=mapping_schema.dialect, - normalize=mapping_schema.normalize, - ) - - def find( - self, table: exp.Table, raise_on_missing: bool = True, ensure_data_types: bool = False - ) -> t.Optional[t.Any]: - schema = super().find( - table, raise_on_missing=raise_on_missing, ensure_data_types=ensure_data_types - ) - if ensure_data_types and isinstance(schema, dict): - schema = { - col: self._to_data_type(dtype) if isinstance(dtype, str) else dtype - for col, dtype in schema.items() - } - - return schema - - def copy(self, **kwargs) -> MappingSchema: - return MappingSchema( - **{ # type: ignore - "schema": self.mapping.copy(), - "visible": self.visible.copy(), - "dialect": self.dialect, - "normalize": self.normalize, - **kwargs, - } - ) - - def add_table( - self, - table: exp.Table | str, - column_mapping: t.Optional[ColumnMapping] = None, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - match_depth: bool = True, - ) -> None: - """ - Register or update a table. Updates are only performed if a new column mapping is provided. - The added table must have the necessary number of qualifiers in its path to match the schema's nesting level. - - Args: - table: the `Table` expression instance or string representing the table. - column_mapping: a column mapping that describes the structure of the table. - dialect: the SQL dialect that will be used to parse `table` if it's a string. - normalize: whether to normalize identifiers according to the dialect of interest. - match_depth: whether to enforce that the table must match the schema's depth or not. - """ - normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) - - if match_depth and not self.empty and len(normalized_table.parts) != self.depth(): - raise SchemaError( - f"Table {normalized_table.sql(dialect=self.dialect)} must match the " - f"schema's nesting level: {self.depth()}." 
- ) - - normalized_column_mapping = { - self._normalize_name(key, dialect=dialect, normalize=normalize): value - for key, value in ensure_column_mapping(column_mapping).items() - } - - schema = self.find(normalized_table, raise_on_missing=False) - if schema and not normalized_column_mapping: - return - - parts = self.table_parts(normalized_table) - - nested_set(self.mapping, tuple(reversed(parts)), normalized_column_mapping) - new_trie([parts], self.mapping_trie) - - def column_names( - self, - table: exp.Table | str, - only_visible: bool = False, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - ) -> t.List[str]: - normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) - - schema = self.find(normalized_table) - if schema is None: - return [] - - if not only_visible or not self.visible: - return list(schema) - - visible = self.nested_get(self.table_parts(normalized_table), self.visible) or [] - return [col for col in schema if col in visible] - - def get_column_type( - self, - table: exp.Table | str, - column: exp.Column | str, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - ) -> exp.DataType: - normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) - - normalized_column_name = self._normalize_name( - column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize - ) - - table_schema = self.find(normalized_table, raise_on_missing=False) - if table_schema: - column_type = table_schema.get(normalized_column_name) - - if isinstance(column_type, exp.DataType): - return column_type - elif isinstance(column_type, str): - return self._to_data_type(column_type, dialect=dialect) - - return exp.DataType.build("unknown") - - def has_column( - self, - table: exp.Table | str, - column: exp.Column | str, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - ) -> bool: - normalized_table = self._normalize_table(table, dialect=dialect, normalize=normalize) - - normalized_column_name = self._normalize_name( - column if isinstance(column, str) else column.this, dialect=dialect, normalize=normalize - ) - - table_schema = self.find(normalized_table, raise_on_missing=False) - return normalized_column_name in table_schema if table_schema else False - - def _normalize(self, schema: t.Dict) -> t.Dict: - """ - Normalizes all identifiers in the schema. - - Args: - schema: the schema to normalize. - - Returns: - The normalized schema mapping. - """ - normalized_mapping: t.Dict = {} - flattened_schema = flatten_schema(schema) - error_msg = "Table {} must match the schema's nesting level: {}." 
- - for keys in flattened_schema: - columns = nested_get(schema, *zip(keys, keys)) - - if not isinstance(columns, dict): - raise SchemaError(error_msg.format(".".join(keys[:-1]), len(flattened_schema[0]))) - if not columns: - raise SchemaError(f"Table {'.'.join(keys[:-1])} must have at least one column") - if isinstance(first(columns.values()), dict): - raise SchemaError( - error_msg.format( - ".".join(keys + flatten_schema(columns)[0]), len(flattened_schema[0]) - ), - ) - - normalized_keys = [self._normalize_name(key, is_table=True) for key in keys] - for column_name, column_type in columns.items(): - nested_set( - normalized_mapping, - normalized_keys + [self._normalize_name(column_name)], - column_type, - ) - - return normalized_mapping - - def _normalize_table( - self, - table: exp.Table | str, - dialect: DialectType = None, - normalize: t.Optional[bool] = None, - ) -> exp.Table: - dialect = dialect or self.dialect - normalize = self.normalize if normalize is None else normalize - - normalized_table = exp.maybe_parse(table, into=exp.Table, dialect=dialect, copy=normalize) - - if normalize: - for part in normalized_table.parts: - if isinstance(part, exp.Identifier): - part.replace( - normalize_name(part, dialect=dialect, is_table=True, normalize=normalize) - ) - - return normalized_table - - def _normalize_name( - self, - name: str | exp.Identifier, - dialect: DialectType = None, - is_table: bool = False, - normalize: t.Optional[bool] = None, - ) -> str: - return normalize_name( - name, - dialect=dialect or self.dialect, - is_table=is_table, - normalize=self.normalize if normalize is None else normalize, - ).name - - def depth(self) -> int: - if not self.empty and not self._depth: - # The columns themselves are a mapping, but we don't want to include those - self._depth = super().depth() - 1 - return self._depth - - def _to_data_type(self, schema_type: str, dialect: DialectType = None) -> exp.DataType: - """ - Convert a type represented as a string to the corresponding `sqlglot.exp.DataType` object. - - Args: - schema_type: the type we want to convert. - dialect: the SQL dialect that will be used to parse `schema_type`, if needed. - - Returns: - The resulting expression type. 
- """ - if schema_type not in self._type_mapping_cache: - dialect = dialect or self.dialect - udt = Dialect.get_or_raise(dialect).SUPPORTS_USER_DEFINED_TYPES - - try: - expression = exp.DataType.build(schema_type, dialect=dialect, udt=udt) - self._type_mapping_cache[schema_type] = expression - except AttributeError: - in_dialect = f" in dialect {dialect}" if dialect else "" - raise SchemaError(f"Failed to build type '{schema_type}'{in_dialect}.") - - return self._type_mapping_cache[schema_type] - - -def normalize_name( - identifier: str | exp.Identifier, - dialect: DialectType = None, - is_table: bool = False, - normalize: t.Optional[bool] = True, -) -> exp.Identifier: - if isinstance(identifier, str): - identifier = exp.parse_identifier(identifier, dialect=dialect) - - if not normalize: - return identifier - - # this is used for normalize_identifier, bigquery has special rules pertaining tables - identifier.meta["is_table"] = is_table - return Dialect.get_or_raise(dialect).normalize_identifier(identifier) - - -def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: - if isinstance(schema, Schema): - return schema - - return MappingSchema(schema, **kwargs) - - -def ensure_column_mapping(mapping: t.Optional[ColumnMapping]) -> t.Dict: - if mapping is None: - return {} - elif isinstance(mapping, dict): - return mapping - elif isinstance(mapping, str): - col_name_type_strs = [x.strip() for x in mapping.split(",")] - return { - name_type_str.split(":")[0].strip(): name_type_str.split(":")[1].strip() - for name_type_str in col_name_type_strs - } - elif isinstance(mapping, list): - return {x.strip(): None for x in mapping} - - raise ValueError(f"Invalid mapping provided: {type(mapping)}") - - -def flatten_schema( - schema: t.Dict, depth: t.Optional[int] = None, keys: t.Optional[t.List[str]] = None -) -> t.List[t.List[str]]: - tables = [] - keys = keys or [] - depth = dict_depth(schema) - 1 if depth is None else depth - - for k, v in schema.items(): - if depth == 1 or not isinstance(v, dict): - tables.append(keys + [k]) - elif depth >= 2: - tables.extend(flatten_schema(v, depth - 1, keys + [k])) - - return tables - - -def nested_get( - d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True -) -> t.Optional[t.Any]: - """ - Get a value for a nested dictionary. - - Args: - d: the dictionary to search. - *path: tuples of (name, key), where: - `key` is the key in the dictionary to get. - `name` is a string to use in the error if `key` isn't found. - - Returns: - The value or None if it doesn't exist. - """ - for name, key in path: - d = d.get(key) # type: ignore - if d is None: - if raise_on_missing: - name = "table" if name == "this" else name - raise ValueError(f"Unknown {name}: {key}") - return None - - return d - - -def nested_set(d: t.Dict, keys: t.Sequence[str], value: t.Any) -> t.Dict: - """ - In-place set a value for a nested dictionary - - Example: - >>> nested_set({}, ["top_key", "second_key"], "value") - {'top_key': {'second_key': 'value'}} - - >>> nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value") - {'top_key': {'third_key': 'third_value', 'second_key': 'value'}} - - Args: - d: dictionary to update. - keys: the keys that makeup the path to `value`. - value: the value to set in the dictionary for the given key path. - - Returns: - The (possibly) updated dictionary. 
- """ - if not keys: - return d - - if len(keys) == 1: - d[keys[0]] = value - return d - - subd = d - for key in keys[:-1]: - if key not in subd: - subd = subd.setdefault(key, {}) - else: - subd = subd[key] - - subd[keys[-1]] = value - return d diff --git a/altimate_packages/sqlglot/serde.py b/altimate_packages/sqlglot/serde.py deleted file mode 100644 index b01903561..000000000 --- a/altimate_packages/sqlglot/serde.py +++ /dev/null @@ -1,68 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import expressions as exp - -if t.TYPE_CHECKING: - JSON = t.Union[dict, list, str, float, int, bool, None] - Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON] - - -def dump(node: Node) -> JSON: - """ - Recursively dump an AST into a JSON-serializable dict. - """ - if isinstance(node, list): - return [dump(i) for i in node] - if isinstance(node, exp.DataType.Type): - return { - "class": "DataType.Type", - "value": node.value, - } - if isinstance(node, exp.Expression): - klass = node.__class__.__qualname__ - if node.__class__.__module__ != exp.__name__: - klass = f"{node.__module__}.{klass}" - obj: t.Dict = { - "class": klass, - "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []}, - } - if node.type: - obj["type"] = dump(node.type) - if node.comments: - obj["comments"] = node.comments - if node._meta is not None: - obj["meta"] = node._meta - - return obj - return node - - -def load(obj: JSON) -> Node: - """ - Recursively load a dict (as returned by `dump`) into an AST. - """ - if isinstance(obj, list): - return [load(i) for i in obj] - if isinstance(obj, dict): - class_name = obj["class"] - - if class_name == "DataType.Type": - return exp.DataType.Type(obj["value"]) - - if "." in class_name: - module_path, class_name = class_name.rsplit(".", maxsplit=1) - module = __import__(module_path, fromlist=[class_name]) - else: - module = exp - - klass = getattr(module, class_name) - - expression = klass(**{k: load(v) for k, v in obj["args"].items()}) - expression.type = t.cast(exp.DataType, load(obj.get("type"))) - expression.comments = obj.get("comments") - expression._meta = obj.get("meta") - - return expression - return obj diff --git a/altimate_packages/sqlglot/time.py b/altimate_packages/sqlglot/time.py deleted file mode 100644 index 66daf1684..000000000 --- a/altimate_packages/sqlglot/time.py +++ /dev/null @@ -1,687 +0,0 @@ -import typing as t -import datetime - -# The generic time format is based on python time.strftime. -# https://docs.python.org/3/library/time.html#time.strftime -from sqlglot.trie import TrieResult, in_trie, new_trie - - -def format_time( - string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None -) -> t.Optional[str]: - """ - Converts a time string given a mapping. - - Examples: - >>> format_time("%Y", {"%Y": "YYYY"}) - 'YYYY' - - Args: - mapping: dictionary of time format to target time format. - trie: optional trie, can be passed in for performance. - - Returns: - The converted time string. 
- """ - if not string: - return None - - start = 0 - end = 1 - size = len(string) - trie = trie or new_trie(mapping) - current = trie - chunks = [] - sym = None - - while end <= size: - chars = string[start:end] - result, current = in_trie(current, chars[-1]) - - if result == TrieResult.FAILED: - if sym: - end -= 1 - chars = sym - sym = None - else: - chars = chars[0] - end = start + 1 - - start += len(chars) - chunks.append(chars) - current = trie - elif result == TrieResult.EXISTS: - sym = chars - - end += 1 - - if result != TrieResult.FAILED and end > size: - chunks.append(chars) - - return "".join(mapping.get(chars, chars) for chars in chunks) - - -TIMEZONES = { - tz.lower() - for tz in ( - "Africa/Abidjan", - "Africa/Accra", - "Africa/Addis_Ababa", - "Africa/Algiers", - "Africa/Asmara", - "Africa/Asmera", - "Africa/Bamako", - "Africa/Bangui", - "Africa/Banjul", - "Africa/Bissau", - "Africa/Blantyre", - "Africa/Brazzaville", - "Africa/Bujumbura", - "Africa/Cairo", - "Africa/Casablanca", - "Africa/Ceuta", - "Africa/Conakry", - "Africa/Dakar", - "Africa/Dar_es_Salaam", - "Africa/Djibouti", - "Africa/Douala", - "Africa/El_Aaiun", - "Africa/Freetown", - "Africa/Gaborone", - "Africa/Harare", - "Africa/Johannesburg", - "Africa/Juba", - "Africa/Kampala", - "Africa/Khartoum", - "Africa/Kigali", - "Africa/Kinshasa", - "Africa/Lagos", - "Africa/Libreville", - "Africa/Lome", - "Africa/Luanda", - "Africa/Lubumbashi", - "Africa/Lusaka", - "Africa/Malabo", - "Africa/Maputo", - "Africa/Maseru", - "Africa/Mbabane", - "Africa/Mogadishu", - "Africa/Monrovia", - "Africa/Nairobi", - "Africa/Ndjamena", - "Africa/Niamey", - "Africa/Nouakchott", - "Africa/Ouagadougou", - "Africa/Porto-Novo", - "Africa/Sao_Tome", - "Africa/Timbuktu", - "Africa/Tripoli", - "Africa/Tunis", - "Africa/Windhoek", - "America/Adak", - "America/Anchorage", - "America/Anguilla", - "America/Antigua", - "America/Araguaina", - "America/Argentina/Buenos_Aires", - "America/Argentina/Catamarca", - "America/Argentina/ComodRivadavia", - "America/Argentina/Cordoba", - "America/Argentina/Jujuy", - "America/Argentina/La_Rioja", - "America/Argentina/Mendoza", - "America/Argentina/Rio_Gallegos", - "America/Argentina/Salta", - "America/Argentina/San_Juan", - "America/Argentina/San_Luis", - "America/Argentina/Tucuman", - "America/Argentina/Ushuaia", - "America/Aruba", - "America/Asuncion", - "America/Atikokan", - "America/Atka", - "America/Bahia", - "America/Bahia_Banderas", - "America/Barbados", - "America/Belem", - "America/Belize", - "America/Blanc-Sablon", - "America/Boa_Vista", - "America/Bogota", - "America/Boise", - "America/Buenos_Aires", - "America/Cambridge_Bay", - "America/Campo_Grande", - "America/Cancun", - "America/Caracas", - "America/Catamarca", - "America/Cayenne", - "America/Cayman", - "America/Chicago", - "America/Chihuahua", - "America/Ciudad_Juarez", - "America/Coral_Harbour", - "America/Cordoba", - "America/Costa_Rica", - "America/Creston", - "America/Cuiaba", - "America/Curacao", - "America/Danmarkshavn", - "America/Dawson", - "America/Dawson_Creek", - "America/Denver", - "America/Detroit", - "America/Dominica", - "America/Edmonton", - "America/Eirunepe", - "America/El_Salvador", - "America/Ensenada", - "America/Fort_Nelson", - "America/Fort_Wayne", - "America/Fortaleza", - "America/Glace_Bay", - "America/Godthab", - "America/Goose_Bay", - "America/Grand_Turk", - "America/Grenada", - "America/Guadeloupe", - "America/Guatemala", - "America/Guayaquil", - "America/Guyana", - "America/Halifax", - "America/Havana", - 
"America/Hermosillo", - "America/Indiana/Indianapolis", - "America/Indiana/Knox", - "America/Indiana/Marengo", - "America/Indiana/Petersburg", - "America/Indiana/Tell_City", - "America/Indiana/Vevay", - "America/Indiana/Vincennes", - "America/Indiana/Winamac", - "America/Indianapolis", - "America/Inuvik", - "America/Iqaluit", - "America/Jamaica", - "America/Jujuy", - "America/Juneau", - "America/Kentucky/Louisville", - "America/Kentucky/Monticello", - "America/Knox_IN", - "America/Kralendijk", - "America/La_Paz", - "America/Lima", - "America/Los_Angeles", - "America/Louisville", - "America/Lower_Princes", - "America/Maceio", - "America/Managua", - "America/Manaus", - "America/Marigot", - "America/Martinique", - "America/Matamoros", - "America/Mazatlan", - "America/Mendoza", - "America/Menominee", - "America/Merida", - "America/Metlakatla", - "America/Mexico_City", - "America/Miquelon", - "America/Moncton", - "America/Monterrey", - "America/Montevideo", - "America/Montreal", - "America/Montserrat", - "America/Nassau", - "America/New_York", - "America/Nipigon", - "America/Nome", - "America/Noronha", - "America/North_Dakota/Beulah", - "America/North_Dakota/Center", - "America/North_Dakota/New_Salem", - "America/Nuuk", - "America/Ojinaga", - "America/Panama", - "America/Pangnirtung", - "America/Paramaribo", - "America/Phoenix", - "America/Port-au-Prince", - "America/Port_of_Spain", - "America/Porto_Acre", - "America/Porto_Velho", - "America/Puerto_Rico", - "America/Punta_Arenas", - "America/Rainy_River", - "America/Rankin_Inlet", - "America/Recife", - "America/Regina", - "America/Resolute", - "America/Rio_Branco", - "America/Rosario", - "America/Santa_Isabel", - "America/Santarem", - "America/Santiago", - "America/Santo_Domingo", - "America/Sao_Paulo", - "America/Scoresbysund", - "America/Shiprock", - "America/Sitka", - "America/St_Barthelemy", - "America/St_Johns", - "America/St_Kitts", - "America/St_Lucia", - "America/St_Thomas", - "America/St_Vincent", - "America/Swift_Current", - "America/Tegucigalpa", - "America/Thule", - "America/Thunder_Bay", - "America/Tijuana", - "America/Toronto", - "America/Tortola", - "America/Vancouver", - "America/Virgin", - "America/Whitehorse", - "America/Winnipeg", - "America/Yakutat", - "America/Yellowknife", - "Antarctica/Casey", - "Antarctica/Davis", - "Antarctica/DumontDUrville", - "Antarctica/Macquarie", - "Antarctica/Mawson", - "Antarctica/McMurdo", - "Antarctica/Palmer", - "Antarctica/Rothera", - "Antarctica/South_Pole", - "Antarctica/Syowa", - "Antarctica/Troll", - "Antarctica/Vostok", - "Arctic/Longyearbyen", - "Asia/Aden", - "Asia/Almaty", - "Asia/Amman", - "Asia/Anadyr", - "Asia/Aqtau", - "Asia/Aqtobe", - "Asia/Ashgabat", - "Asia/Ashkhabad", - "Asia/Atyrau", - "Asia/Baghdad", - "Asia/Bahrain", - "Asia/Baku", - "Asia/Bangkok", - "Asia/Barnaul", - "Asia/Beirut", - "Asia/Bishkek", - "Asia/Brunei", - "Asia/Calcutta", - "Asia/Chita", - "Asia/Choibalsan", - "Asia/Chongqing", - "Asia/Chungking", - "Asia/Colombo", - "Asia/Dacca", - "Asia/Damascus", - "Asia/Dhaka", - "Asia/Dili", - "Asia/Dubai", - "Asia/Dushanbe", - "Asia/Famagusta", - "Asia/Gaza", - "Asia/Harbin", - "Asia/Hebron", - "Asia/Ho_Chi_Minh", - "Asia/Hong_Kong", - "Asia/Hovd", - "Asia/Irkutsk", - "Asia/Istanbul", - "Asia/Jakarta", - "Asia/Jayapura", - "Asia/Jerusalem", - "Asia/Kabul", - "Asia/Kamchatka", - "Asia/Karachi", - "Asia/Kashgar", - "Asia/Kathmandu", - "Asia/Katmandu", - "Asia/Khandyga", - "Asia/Kolkata", - "Asia/Krasnoyarsk", - "Asia/Kuala_Lumpur", - "Asia/Kuching", - "Asia/Kuwait", - 
"Asia/Macao", - "Asia/Macau", - "Asia/Magadan", - "Asia/Makassar", - "Asia/Manila", - "Asia/Muscat", - "Asia/Nicosia", - "Asia/Novokuznetsk", - "Asia/Novosibirsk", - "Asia/Omsk", - "Asia/Oral", - "Asia/Phnom_Penh", - "Asia/Pontianak", - "Asia/Pyongyang", - "Asia/Qatar", - "Asia/Qostanay", - "Asia/Qyzylorda", - "Asia/Rangoon", - "Asia/Riyadh", - "Asia/Saigon", - "Asia/Sakhalin", - "Asia/Samarkand", - "Asia/Seoul", - "Asia/Shanghai", - "Asia/Singapore", - "Asia/Srednekolymsk", - "Asia/Taipei", - "Asia/Tashkent", - "Asia/Tbilisi", - "Asia/Tehran", - "Asia/Tel_Aviv", - "Asia/Thimbu", - "Asia/Thimphu", - "Asia/Tokyo", - "Asia/Tomsk", - "Asia/Ujung_Pandang", - "Asia/Ulaanbaatar", - "Asia/Ulan_Bator", - "Asia/Urumqi", - "Asia/Ust-Nera", - "Asia/Vientiane", - "Asia/Vladivostok", - "Asia/Yakutsk", - "Asia/Yangon", - "Asia/Yekaterinburg", - "Asia/Yerevan", - "Atlantic/Azores", - "Atlantic/Bermuda", - "Atlantic/Canary", - "Atlantic/Cape_Verde", - "Atlantic/Faeroe", - "Atlantic/Faroe", - "Atlantic/Jan_Mayen", - "Atlantic/Madeira", - "Atlantic/Reykjavik", - "Atlantic/South_Georgia", - "Atlantic/St_Helena", - "Atlantic/Stanley", - "Australia/ACT", - "Australia/Adelaide", - "Australia/Brisbane", - "Australia/Broken_Hill", - "Australia/Canberra", - "Australia/Currie", - "Australia/Darwin", - "Australia/Eucla", - "Australia/Hobart", - "Australia/LHI", - "Australia/Lindeman", - "Australia/Lord_Howe", - "Australia/Melbourne", - "Australia/NSW", - "Australia/North", - "Australia/Perth", - "Australia/Queensland", - "Australia/South", - "Australia/Sydney", - "Australia/Tasmania", - "Australia/Victoria", - "Australia/West", - "Australia/Yancowinna", - "Brazil/Acre", - "Brazil/DeNoronha", - "Brazil/East", - "Brazil/West", - "CET", - "CST6CDT", - "Canada/Atlantic", - "Canada/Central", - "Canada/Eastern", - "Canada/Mountain", - "Canada/Newfoundland", - "Canada/Pacific", - "Canada/Saskatchewan", - "Canada/Yukon", - "Chile/Continental", - "Chile/EasterIsland", - "Cuba", - "EET", - "EST", - "EST5EDT", - "Egypt", - "Eire", - "Etc/GMT", - "Etc/GMT+0", - "Etc/GMT+1", - "Etc/GMT+10", - "Etc/GMT+11", - "Etc/GMT+12", - "Etc/GMT+2", - "Etc/GMT+3", - "Etc/GMT+4", - "Etc/GMT+5", - "Etc/GMT+6", - "Etc/GMT+7", - "Etc/GMT+8", - "Etc/GMT+9", - "Etc/GMT-0", - "Etc/GMT-1", - "Etc/GMT-10", - "Etc/GMT-11", - "Etc/GMT-12", - "Etc/GMT-13", - "Etc/GMT-14", - "Etc/GMT-2", - "Etc/GMT-3", - "Etc/GMT-4", - "Etc/GMT-5", - "Etc/GMT-6", - "Etc/GMT-7", - "Etc/GMT-8", - "Etc/GMT-9", - "Etc/GMT0", - "Etc/Greenwich", - "Etc/UCT", - "Etc/UTC", - "Etc/Universal", - "Etc/Zulu", - "Europe/Amsterdam", - "Europe/Andorra", - "Europe/Astrakhan", - "Europe/Athens", - "Europe/Belfast", - "Europe/Belgrade", - "Europe/Berlin", - "Europe/Bratislava", - "Europe/Brussels", - "Europe/Bucharest", - "Europe/Budapest", - "Europe/Busingen", - "Europe/Chisinau", - "Europe/Copenhagen", - "Europe/Dublin", - "Europe/Gibraltar", - "Europe/Guernsey", - "Europe/Helsinki", - "Europe/Isle_of_Man", - "Europe/Istanbul", - "Europe/Jersey", - "Europe/Kaliningrad", - "Europe/Kiev", - "Europe/Kirov", - "Europe/Kyiv", - "Europe/Lisbon", - "Europe/Ljubljana", - "Europe/London", - "Europe/Luxembourg", - "Europe/Madrid", - "Europe/Malta", - "Europe/Mariehamn", - "Europe/Minsk", - "Europe/Monaco", - "Europe/Moscow", - "Europe/Nicosia", - "Europe/Oslo", - "Europe/Paris", - "Europe/Podgorica", - "Europe/Prague", - "Europe/Riga", - "Europe/Rome", - "Europe/Samara", - "Europe/San_Marino", - "Europe/Sarajevo", - "Europe/Saratov", - "Europe/Simferopol", - "Europe/Skopje", - "Europe/Sofia", - 
"Europe/Stockholm", - "Europe/Tallinn", - "Europe/Tirane", - "Europe/Tiraspol", - "Europe/Ulyanovsk", - "Europe/Uzhgorod", - "Europe/Vaduz", - "Europe/Vatican", - "Europe/Vienna", - "Europe/Vilnius", - "Europe/Volgograd", - "Europe/Warsaw", - "Europe/Zagreb", - "Europe/Zaporozhye", - "Europe/Zurich", - "GB", - "GB-Eire", - "GMT", - "GMT+0", - "GMT-0", - "GMT0", - "Greenwich", - "HST", - "Hongkong", - "Iceland", - "Indian/Antananarivo", - "Indian/Chagos", - "Indian/Christmas", - "Indian/Cocos", - "Indian/Comoro", - "Indian/Kerguelen", - "Indian/Mahe", - "Indian/Maldives", - "Indian/Mauritius", - "Indian/Mayotte", - "Indian/Reunion", - "Iran", - "Israel", - "Jamaica", - "Japan", - "Kwajalein", - "Libya", - "MET", - "MST", - "MST7MDT", - "Mexico/BajaNorte", - "Mexico/BajaSur", - "Mexico/General", - "NZ", - "NZ-CHAT", - "Navajo", - "PRC", - "PST8PDT", - "Pacific/Apia", - "Pacific/Auckland", - "Pacific/Bougainville", - "Pacific/Chatham", - "Pacific/Chuuk", - "Pacific/Easter", - "Pacific/Efate", - "Pacific/Enderbury", - "Pacific/Fakaofo", - "Pacific/Fiji", - "Pacific/Funafuti", - "Pacific/Galapagos", - "Pacific/Gambier", - "Pacific/Guadalcanal", - "Pacific/Guam", - "Pacific/Honolulu", - "Pacific/Johnston", - "Pacific/Kanton", - "Pacific/Kiritimati", - "Pacific/Kosrae", - "Pacific/Kwajalein", - "Pacific/Majuro", - "Pacific/Marquesas", - "Pacific/Midway", - "Pacific/Nauru", - "Pacific/Niue", - "Pacific/Norfolk", - "Pacific/Noumea", - "Pacific/Pago_Pago", - "Pacific/Palau", - "Pacific/Pitcairn", - "Pacific/Pohnpei", - "Pacific/Ponape", - "Pacific/Port_Moresby", - "Pacific/Rarotonga", - "Pacific/Saipan", - "Pacific/Samoa", - "Pacific/Tahiti", - "Pacific/Tarawa", - "Pacific/Tongatapu", - "Pacific/Truk", - "Pacific/Wake", - "Pacific/Wallis", - "Pacific/Yap", - "Poland", - "Portugal", - "ROC", - "ROK", - "Singapore", - "Turkey", - "UCT", - "US/Alaska", - "US/Aleutian", - "US/Arizona", - "US/Central", - "US/East-Indiana", - "US/Eastern", - "US/Hawaii", - "US/Indiana-Starke", - "US/Michigan", - "US/Mountain", - "US/Pacific", - "US/Samoa", - "UTC", - "Universal", - "W-SU", - "WET", - "Zulu", - ) -} - - -def subsecond_precision(timestamp_literal: str) -> int: - """ - Given an ISO-8601 timestamp literal, eg '2023-01-01 12:13:14.123456+00:00' - figure out its subsecond precision so we can construct types like DATETIME(6) - - Note that in practice, this is either 3 or 6 digits (3 = millisecond precision, 6 = microsecond precision) - - 6 is the maximum because strftime's '%f' formats to microseconds and almost every database supports microsecond precision in timestamps - - Except Presto/Trino which in most cases only supports millisecond precision but will still honour '%f' and format to microseconds (replacing the remaining 3 digits with 0's) - - Python prior to 3.11 only supports 0, 3 or 6 digits in a timestamp literal. 
Any other amounts will throw a 'ValueError: Invalid isoformat string:' error - """ - try: - parsed = datetime.datetime.fromisoformat(timestamp_literal) - subsecond_digit_count = len(str(parsed.microsecond).rstrip("0")) - precision = 0 - if subsecond_digit_count > 3: - precision = 6 - elif subsecond_digit_count > 0: - precision = 3 - return precision - except ValueError: - return 0 diff --git a/altimate_packages/sqlglot/tokens.py b/altimate_packages/sqlglot/tokens.py deleted file mode 100644 index 0a5667551..000000000 --- a/altimate_packages/sqlglot/tokens.py +++ /dev/null @@ -1,1520 +0,0 @@ -from __future__ import annotations - -import os -import typing as t -from enum import auto - -from sqlglot.errors import SqlglotError, TokenError -from sqlglot.helper import AutoName -from sqlglot.trie import TrieResult, in_trie, new_trie - -if t.TYPE_CHECKING: - from sqlglot.dialects.dialect import DialectType - - -try: - from sqlglotrs import ( # type: ignore - Tokenizer as RsTokenizer, - TokenizerDialectSettings as RsTokenizerDialectSettings, - TokenizerSettings as RsTokenizerSettings, - TokenTypeSettings as RsTokenTypeSettings, - ) - - USE_RS_TOKENIZER = os.environ.get("SQLGLOTRS_TOKENIZER", "1") == "1" -except ImportError: - USE_RS_TOKENIZER = False - - -class TokenType(AutoName): - L_PAREN = auto() - R_PAREN = auto() - L_BRACKET = auto() - R_BRACKET = auto() - L_BRACE = auto() - R_BRACE = auto() - COMMA = auto() - DOT = auto() - DASH = auto() - PLUS = auto() - COLON = auto() - DOTCOLON = auto() - DCOLON = auto() - DQMARK = auto() - SEMICOLON = auto() - STAR = auto() - BACKSLASH = auto() - SLASH = auto() - LT = auto() - LTE = auto() - GT = auto() - GTE = auto() - NOT = auto() - EQ = auto() - NEQ = auto() - NULLSAFE_EQ = auto() - COLON_EQ = auto() - AND = auto() - OR = auto() - AMP = auto() - DPIPE = auto() - PIPE_GT = auto() - PIPE = auto() - PIPE_SLASH = auto() - DPIPE_SLASH = auto() - CARET = auto() - CARET_AT = auto() - TILDA = auto() - ARROW = auto() - DARROW = auto() - FARROW = auto() - HASH = auto() - HASH_ARROW = auto() - DHASH_ARROW = auto() - LR_ARROW = auto() - DAT = auto() - LT_AT = auto() - AT_GT = auto() - DOLLAR = auto() - PARAMETER = auto() - SESSION_PARAMETER = auto() - DAMP = auto() - XOR = auto() - DSTAR = auto() - - URI_START = auto() - - BLOCK_START = auto() - BLOCK_END = auto() - - SPACE = auto() - BREAK = auto() - - STRING = auto() - NUMBER = auto() - IDENTIFIER = auto() - DATABASE = auto() - COLUMN = auto() - COLUMN_DEF = auto() - SCHEMA = auto() - TABLE = auto() - WAREHOUSE = auto() - STAGE = auto() - STREAMLIT = auto() - VAR = auto() - BIT_STRING = auto() - HEX_STRING = auto() - BYTE_STRING = auto() - NATIONAL_STRING = auto() - RAW_STRING = auto() - HEREDOC_STRING = auto() - UNICODE_STRING = auto() - - # types - BIT = auto() - BOOLEAN = auto() - TINYINT = auto() - UTINYINT = auto() - SMALLINT = auto() - USMALLINT = auto() - MEDIUMINT = auto() - UMEDIUMINT = auto() - INT = auto() - UINT = auto() - BIGINT = auto() - UBIGINT = auto() - INT128 = auto() - UINT128 = auto() - INT256 = auto() - UINT256 = auto() - FLOAT = auto() - DOUBLE = auto() - UDOUBLE = auto() - DECIMAL = auto() - DECIMAL32 = auto() - DECIMAL64 = auto() - DECIMAL128 = auto() - DECIMAL256 = auto() - UDECIMAL = auto() - BIGDECIMAL = auto() - CHAR = auto() - NCHAR = auto() - VARCHAR = auto() - NVARCHAR = auto() - BPCHAR = auto() - TEXT = auto() - MEDIUMTEXT = auto() - LONGTEXT = auto() - BLOB = auto() - MEDIUMBLOB = auto() - LONGBLOB = auto() - TINYBLOB = auto() - TINYTEXT = auto() - NAME = auto() - BINARY = 
auto() - VARBINARY = auto() - JSON = auto() - JSONB = auto() - TIME = auto() - TIMETZ = auto() - TIMESTAMP = auto() - TIMESTAMPTZ = auto() - TIMESTAMPLTZ = auto() - TIMESTAMPNTZ = auto() - TIMESTAMP_S = auto() - TIMESTAMP_MS = auto() - TIMESTAMP_NS = auto() - DATETIME = auto() - DATETIME2 = auto() - DATETIME64 = auto() - SMALLDATETIME = auto() - DATE = auto() - DATE32 = auto() - INT4RANGE = auto() - INT4MULTIRANGE = auto() - INT8RANGE = auto() - INT8MULTIRANGE = auto() - NUMRANGE = auto() - NUMMULTIRANGE = auto() - TSRANGE = auto() - TSMULTIRANGE = auto() - TSTZRANGE = auto() - TSTZMULTIRANGE = auto() - DATERANGE = auto() - DATEMULTIRANGE = auto() - UUID = auto() - GEOGRAPHY = auto() - NULLABLE = auto() - GEOMETRY = auto() - POINT = auto() - RING = auto() - LINESTRING = auto() - MULTILINESTRING = auto() - POLYGON = auto() - MULTIPOLYGON = auto() - HLLSKETCH = auto() - HSTORE = auto() - SUPER = auto() - SERIAL = auto() - SMALLSERIAL = auto() - BIGSERIAL = auto() - XML = auto() - YEAR = auto() - USERDEFINED = auto() - MONEY = auto() - SMALLMONEY = auto() - ROWVERSION = auto() - IMAGE = auto() - VARIANT = auto() - OBJECT = auto() - INET = auto() - IPADDRESS = auto() - IPPREFIX = auto() - IPV4 = auto() - IPV6 = auto() - ENUM = auto() - ENUM8 = auto() - ENUM16 = auto() - FIXEDSTRING = auto() - LOWCARDINALITY = auto() - NESTED = auto() - AGGREGATEFUNCTION = auto() - SIMPLEAGGREGATEFUNCTION = auto() - TDIGEST = auto() - UNKNOWN = auto() - VECTOR = auto() - DYNAMIC = auto() - VOID = auto() - - # keywords - ALIAS = auto() - ALTER = auto() - ALWAYS = auto() - ALL = auto() - ANTI = auto() - ANY = auto() - APPLY = auto() - ARRAY = auto() - ASC = auto() - ASOF = auto() - ATTACH = auto() - AUTO_INCREMENT = auto() - BEGIN = auto() - BETWEEN = auto() - BULK_COLLECT_INTO = auto() - CACHE = auto() - CASE = auto() - CHARACTER_SET = auto() - CLUSTER_BY = auto() - COLLATE = auto() - COMMAND = auto() - COMMENT = auto() - COMMIT = auto() - CONNECT_BY = auto() - CONSTRAINT = auto() - COPY = auto() - CREATE = auto() - CROSS = auto() - CUBE = auto() - CURRENT_DATE = auto() - CURRENT_DATETIME = auto() - CURRENT_SCHEMA = auto() - CURRENT_TIME = auto() - CURRENT_TIMESTAMP = auto() - CURRENT_USER = auto() - DECLARE = auto() - DEFAULT = auto() - DELETE = auto() - DESC = auto() - DESCRIBE = auto() - DETACH = auto() - DICTIONARY = auto() - DISTINCT = auto() - DISTRIBUTE_BY = auto() - DIV = auto() - DROP = auto() - ELSE = auto() - END = auto() - ESCAPE = auto() - EXCEPT = auto() - EXECUTE = auto() - EXISTS = auto() - FALSE = auto() - FETCH = auto() - FILE_FORMAT = auto() - FILTER = auto() - FINAL = auto() - FIRST = auto() - FOR = auto() - FORCE = auto() - FOREIGN_KEY = auto() - FORMAT = auto() - FROM = auto() - FULL = auto() - FUNCTION = auto() - GET = auto() - GLOB = auto() - GLOBAL = auto() - GRANT = auto() - GROUP_BY = auto() - GROUPING_SETS = auto() - HAVING = auto() - HINT = auto() - IGNORE = auto() - ILIKE = auto() - ILIKE_ANY = auto() - IN = auto() - INDEX = auto() - INNER = auto() - INSERT = auto() - INTERSECT = auto() - INTERVAL = auto() - INTO = auto() - INTRODUCER = auto() - IRLIKE = auto() - IS = auto() - ISNULL = auto() - JOIN = auto() - JOIN_MARKER = auto() - KEEP = auto() - KEY = auto() - KILL = auto() - LANGUAGE = auto() - LATERAL = auto() - LEFT = auto() - LIKE = auto() - LIKE_ANY = auto() - LIMIT = auto() - LIST = auto() - LOAD = auto() - LOCK = auto() - MAP = auto() - MATCH_CONDITION = auto() - MATCH_RECOGNIZE = auto() - MEMBER_OF = auto() - MERGE = auto() - MOD = auto() - MODEL = auto() - NATURAL = 
auto() - NEXT = auto() - NOTHING = auto() - NOTNULL = auto() - NULL = auto() - OBJECT_IDENTIFIER = auto() - OFFSET = auto() - ON = auto() - ONLY = auto() - OPERATOR = auto() - ORDER_BY = auto() - ORDER_SIBLINGS_BY = auto() - ORDERED = auto() - ORDINALITY = auto() - OUTER = auto() - OVER = auto() - OVERLAPS = auto() - OVERWRITE = auto() - PARTITION = auto() - PARTITION_BY = auto() - PERCENT = auto() - PIVOT = auto() - PLACEHOLDER = auto() - POSITIONAL = auto() - PRAGMA = auto() - PREWHERE = auto() - PRIMARY_KEY = auto() - PROCEDURE = auto() - PROPERTIES = auto() - PSEUDO_TYPE = auto() - PUT = auto() - QUALIFY = auto() - QUOTE = auto() - RANGE = auto() - RECURSIVE = auto() - REFRESH = auto() - RENAME = auto() - REPLACE = auto() - RETURNING = auto() - REFERENCES = auto() - RIGHT = auto() - RLIKE = auto() - ROLLBACK = auto() - ROLLUP = auto() - ROW = auto() - ROWS = auto() - SELECT = auto() - SEMI = auto() - SEPARATOR = auto() - SEQUENCE = auto() - SERDE_PROPERTIES = auto() - SET = auto() - SETTINGS = auto() - SHOW = auto() - SIMILAR_TO = auto() - SOME = auto() - SORT_BY = auto() - START_WITH = auto() - STORAGE_INTEGRATION = auto() - STRAIGHT_JOIN = auto() - STRUCT = auto() - SUMMARIZE = auto() - TABLE_SAMPLE = auto() - TAG = auto() - TEMPORARY = auto() - TOP = auto() - THEN = auto() - TRUE = auto() - TRUNCATE = auto() - UNCACHE = auto() - UNION = auto() - UNNEST = auto() - UNPIVOT = auto() - UPDATE = auto() - USE = auto() - USING = auto() - VALUES = auto() - VIEW = auto() - VOLATILE = auto() - WHEN = auto() - WHERE = auto() - WINDOW = auto() - WITH = auto() - UNIQUE = auto() - VERSION_SNAPSHOT = auto() - TIMESTAMP_SNAPSHOT = auto() - OPTION = auto() - SINK = auto() - SOURCE = auto() - ANALYZE = auto() - NAMESPACE = auto() - EXPORT = auto() - - -_ALL_TOKEN_TYPES = list(TokenType) -_TOKEN_TYPE_TO_INDEX = {token_type: i for i, token_type in enumerate(_ALL_TOKEN_TYPES)} - - -class Token: - __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments") - - @classmethod - def number(cls, number: int) -> Token: - """Returns a NUMBER token with `number` as its text.""" - return cls(TokenType.NUMBER, str(number)) - - @classmethod - def string(cls, string: str) -> Token: - """Returns a STRING token with `string` as its text.""" - return cls(TokenType.STRING, string) - - @classmethod - def identifier(cls, identifier: str) -> Token: - """Returns an IDENTIFIER token with `identifier` as its text.""" - return cls(TokenType.IDENTIFIER, identifier) - - @classmethod - def var(cls, var: str) -> Token: - """Returns an VAR token with `var` as its text.""" - return cls(TokenType.VAR, var) - - def __init__( - self, - token_type: TokenType, - text: str, - line: int = 1, - col: int = 1, - start: int = 0, - end: int = 0, - comments: t.Optional[t.List[str]] = None, - ) -> None: - """Token initializer. - - Args: - token_type: The TokenType Enum. - text: The text of the token. - line: The line that the token ends on. - col: The column that the token ends on. - start: The start index of the token. - end: The ending index of the token. - comments: The comments to attach to the token. 
- """ - self.token_type = token_type - self.text = text - self.line = line - self.col = col - self.start = start - self.end = end - self.comments = [] if comments is None else comments - - def __repr__(self) -> str: - attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) - return f"" - - -class _Tokenizer(type): - def __new__(cls, clsname, bases, attrs): - klass = super().__new__(cls, clsname, bases, attrs) - - def _convert_quotes(arr: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]: - return dict( - (item, item) if isinstance(item, str) else (item[0], item[1]) for item in arr - ) - - def _quotes_to_format( - token_type: TokenType, arr: t.List[str | t.Tuple[str, str]] - ) -> t.Dict[str, t.Tuple[str, TokenType]]: - return {k: (v, token_type) for k, v in _convert_quotes(arr).items()} - - klass._QUOTES = _convert_quotes(klass.QUOTES) - klass._IDENTIFIERS = _convert_quotes(klass.IDENTIFIERS) - - klass._FORMAT_STRINGS = { - **{ - p + s: (e, TokenType.NATIONAL_STRING) - for s, e in klass._QUOTES.items() - for p in ("n", "N") - }, - **_quotes_to_format(TokenType.BIT_STRING, klass.BIT_STRINGS), - **_quotes_to_format(TokenType.BYTE_STRING, klass.BYTE_STRINGS), - **_quotes_to_format(TokenType.HEX_STRING, klass.HEX_STRINGS), - **_quotes_to_format(TokenType.RAW_STRING, klass.RAW_STRINGS), - **_quotes_to_format(TokenType.HEREDOC_STRING, klass.HEREDOC_STRINGS), - **_quotes_to_format(TokenType.UNICODE_STRING, klass.UNICODE_STRINGS), - } - - klass._STRING_ESCAPES = set(klass.STRING_ESCAPES) - klass._IDENTIFIER_ESCAPES = set(klass.IDENTIFIER_ESCAPES) - klass._COMMENTS = { - **dict( - (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) - for comment in klass.COMMENTS - ), - "{#": "#}", # Ensure Jinja comments are tokenized correctly in all dialects - } - if klass.HINT_START in klass.KEYWORDS: - klass._COMMENTS[klass.HINT_START] = "*/" - - klass._KEYWORD_TRIE = new_trie( - key.upper() - for key in ( - *klass.KEYWORDS, - *klass._COMMENTS, - *klass._QUOTES, - *klass._FORMAT_STRINGS, - ) - if " " in key or any(single in key for single in klass.SINGLE_TOKENS) - ) - - if USE_RS_TOKENIZER: - settings = RsTokenizerSettings( - white_space={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.WHITE_SPACE.items()}, - single_tokens={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.SINGLE_TOKENS.items()}, - keywords={k: _TOKEN_TYPE_TO_INDEX[v] for k, v in klass.KEYWORDS.items()}, - numeric_literals=klass.NUMERIC_LITERALS, - identifiers=klass._IDENTIFIERS, - identifier_escapes=klass._IDENTIFIER_ESCAPES, - string_escapes=klass._STRING_ESCAPES, - quotes=klass._QUOTES, - format_strings={ - k: (v1, _TOKEN_TYPE_TO_INDEX[v2]) - for k, (v1, v2) in klass._FORMAT_STRINGS.items() - }, - has_bit_strings=bool(klass.BIT_STRINGS), - has_hex_strings=bool(klass.HEX_STRINGS), - comments=klass._COMMENTS, - var_single_tokens=klass.VAR_SINGLE_TOKENS, - commands={_TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMANDS}, - command_prefix_tokens={ - _TOKEN_TYPE_TO_INDEX[v] for v in klass.COMMAND_PREFIX_TOKENS - }, - heredoc_tag_is_identifier=klass.HEREDOC_TAG_IS_IDENTIFIER, - string_escapes_allowed_in_raw_strings=klass.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS, - nested_comments=klass.NESTED_COMMENTS, - hint_start=klass.HINT_START, - tokens_preceding_hint={ - _TOKEN_TYPE_TO_INDEX[v] for v in klass.TOKENS_PRECEDING_HINT - }, - ) - token_types = RsTokenTypeSettings( - bit_string=_TOKEN_TYPE_TO_INDEX[TokenType.BIT_STRING], - break_=_TOKEN_TYPE_TO_INDEX[TokenType.BREAK], - dcolon=_TOKEN_TYPE_TO_INDEX[TokenType.DCOLON], - 
heredoc_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEREDOC_STRING], - raw_string=_TOKEN_TYPE_TO_INDEX[TokenType.RAW_STRING], - hex_string=_TOKEN_TYPE_TO_INDEX[TokenType.HEX_STRING], - identifier=_TOKEN_TYPE_TO_INDEX[TokenType.IDENTIFIER], - number=_TOKEN_TYPE_TO_INDEX[TokenType.NUMBER], - parameter=_TOKEN_TYPE_TO_INDEX[TokenType.PARAMETER], - semicolon=_TOKEN_TYPE_TO_INDEX[TokenType.SEMICOLON], - string=_TOKEN_TYPE_TO_INDEX[TokenType.STRING], - var=_TOKEN_TYPE_TO_INDEX[TokenType.VAR], - heredoc_string_alternative=_TOKEN_TYPE_TO_INDEX[klass.HEREDOC_STRING_ALTERNATIVE], - hint=_TOKEN_TYPE_TO_INDEX[TokenType.HINT], - ) - klass._RS_TOKENIZER = RsTokenizer(settings, token_types) - else: - klass._RS_TOKENIZER = None - - return klass - - -class Tokenizer(metaclass=_Tokenizer): - SINGLE_TOKENS = { - "(": TokenType.L_PAREN, - ")": TokenType.R_PAREN, - "[": TokenType.L_BRACKET, - "]": TokenType.R_BRACKET, - "{": TokenType.L_BRACE, - "}": TokenType.R_BRACE, - "&": TokenType.AMP, - "^": TokenType.CARET, - ":": TokenType.COLON, - ",": TokenType.COMMA, - ".": TokenType.DOT, - "-": TokenType.DASH, - "=": TokenType.EQ, - ">": TokenType.GT, - "<": TokenType.LT, - "%": TokenType.MOD, - "!": TokenType.NOT, - "|": TokenType.PIPE, - "+": TokenType.PLUS, - ";": TokenType.SEMICOLON, - "/": TokenType.SLASH, - "\\": TokenType.BACKSLASH, - "*": TokenType.STAR, - "~": TokenType.TILDA, - "?": TokenType.PLACEHOLDER, - "@": TokenType.PARAMETER, - "#": TokenType.HASH, - # Used for breaking a var like x'y' but nothing else the token type doesn't matter - "'": TokenType.UNKNOWN, - "`": TokenType.UNKNOWN, - '"': TokenType.UNKNOWN, - } - - BIT_STRINGS: t.List[str | t.Tuple[str, str]] = [] - BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = [] - HEX_STRINGS: t.List[str | t.Tuple[str, str]] = [] - RAW_STRINGS: t.List[str | t.Tuple[str, str]] = [] - HEREDOC_STRINGS: t.List[str | t.Tuple[str, str]] = [] - UNICODE_STRINGS: t.List[str | t.Tuple[str, str]] = [] - IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"'] - QUOTES: t.List[t.Tuple[str, str] | str] = ["'"] - STRING_ESCAPES = ["'"] - VAR_SINGLE_TOKENS: t.Set[str] = set() - - # The strings in this list can always be used as escapes, regardless of the surrounding - # identifier delimiters. By default, the closing delimiter is assumed to also act as an - # identifier escape, e.g. 
if we use double-quotes, then they also act as escapes: "x""" - IDENTIFIER_ESCAPES: t.List[str] = [] - - # Whether the heredoc tags follow the same lexical rules as unquoted identifiers - HEREDOC_TAG_IS_IDENTIFIER = False - - # Token that we'll generate as a fallback if the heredoc prefix doesn't correspond to a heredoc - HEREDOC_STRING_ALTERNATIVE = TokenType.VAR - - # Whether string escape characters function as such when placed within raw strings - STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = True - - NESTED_COMMENTS = True - - HINT_START = "/*+" - - TOKENS_PRECEDING_HINT = {TokenType.SELECT, TokenType.INSERT, TokenType.UPDATE, TokenType.DELETE} - - # Autofilled - _COMMENTS: t.Dict[str, str] = {} - _FORMAT_STRINGS: t.Dict[str, t.Tuple[str, TokenType]] = {} - _IDENTIFIERS: t.Dict[str, str] = {} - _IDENTIFIER_ESCAPES: t.Set[str] = set() - _QUOTES: t.Dict[str, str] = {} - _STRING_ESCAPES: t.Set[str] = set() - _KEYWORD_TRIE: t.Dict = {} - _RS_TOKENIZER: t.Optional[t.Any] = None - - KEYWORDS: t.Dict[str, TokenType] = { - **{f"{{%{postfix}": TokenType.BLOCK_START for postfix in ("", "+", "-")}, - **{f"{prefix}%}}": TokenType.BLOCK_END for prefix in ("", "+", "-")}, - **{f"{{{{{postfix}": TokenType.BLOCK_START for postfix in ("+", "-")}, - **{f"{prefix}}}}}": TokenType.BLOCK_END for prefix in ("+", "-")}, - HINT_START: TokenType.HINT, - "==": TokenType.EQ, - "::": TokenType.DCOLON, - "||": TokenType.DPIPE, - "|>": TokenType.PIPE_GT, - ">=": TokenType.GTE, - "<=": TokenType.LTE, - "<>": TokenType.NEQ, - "!=": TokenType.NEQ, - ":=": TokenType.COLON_EQ, - "<=>": TokenType.NULLSAFE_EQ, - "->": TokenType.ARROW, - "->>": TokenType.DARROW, - "=>": TokenType.FARROW, - "#>": TokenType.HASH_ARROW, - "#>>": TokenType.DHASH_ARROW, - "<->": TokenType.LR_ARROW, - "&&": TokenType.DAMP, - "??": TokenType.DQMARK, - "~~~": TokenType.GLOB, - "~~": TokenType.LIKE, - "~~*": TokenType.ILIKE, - "~*": TokenType.IRLIKE, - "ALL": TokenType.ALL, - "ALWAYS": TokenType.ALWAYS, - "AND": TokenType.AND, - "ANTI": TokenType.ANTI, - "ANY": TokenType.ANY, - "ASC": TokenType.ASC, - "AS": TokenType.ALIAS, - "ASOF": TokenType.ASOF, - "AUTOINCREMENT": TokenType.AUTO_INCREMENT, - "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, - "BEGIN": TokenType.BEGIN, - "BETWEEN": TokenType.BETWEEN, - "CACHE": TokenType.CACHE, - "UNCACHE": TokenType.UNCACHE, - "CASE": TokenType.CASE, - "CHARACTER SET": TokenType.CHARACTER_SET, - "CLUSTER BY": TokenType.CLUSTER_BY, - "COLLATE": TokenType.COLLATE, - "COLUMN": TokenType.COLUMN, - "COMMIT": TokenType.COMMIT, - "CONNECT BY": TokenType.CONNECT_BY, - "CONSTRAINT": TokenType.CONSTRAINT, - "COPY": TokenType.COPY, - "CREATE": TokenType.CREATE, - "CROSS": TokenType.CROSS, - "CUBE": TokenType.CUBE, - "CURRENT_DATE": TokenType.CURRENT_DATE, - "CURRENT_SCHEMA": TokenType.CURRENT_SCHEMA, - "CURRENT_TIME": TokenType.CURRENT_TIME, - "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, - "CURRENT_USER": TokenType.CURRENT_USER, - "DATABASE": TokenType.DATABASE, - "DEFAULT": TokenType.DEFAULT, - "DELETE": TokenType.DELETE, - "DESC": TokenType.DESC, - "DESCRIBE": TokenType.DESCRIBE, - "DISTINCT": TokenType.DISTINCT, - "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, - "DIV": TokenType.DIV, - "DROP": TokenType.DROP, - "ELSE": TokenType.ELSE, - "END": TokenType.END, - "ENUM": TokenType.ENUM, - "ESCAPE": TokenType.ESCAPE, - "EXCEPT": TokenType.EXCEPT, - "EXECUTE": TokenType.EXECUTE, - "EXISTS": TokenType.EXISTS, - "FALSE": TokenType.FALSE, - "FETCH": TokenType.FETCH, - "FILTER": TokenType.FILTER, - "FIRST": TokenType.FIRST, - "FULL": 
TokenType.FULL, - "FUNCTION": TokenType.FUNCTION, - "FOR": TokenType.FOR, - "FOREIGN KEY": TokenType.FOREIGN_KEY, - "FORMAT": TokenType.FORMAT, - "FROM": TokenType.FROM, - "GEOGRAPHY": TokenType.GEOGRAPHY, - "GEOMETRY": TokenType.GEOMETRY, - "GLOB": TokenType.GLOB, - "GROUP BY": TokenType.GROUP_BY, - "GROUPING SETS": TokenType.GROUPING_SETS, - "HAVING": TokenType.HAVING, - "ILIKE": TokenType.ILIKE, - "IN": TokenType.IN, - "INDEX": TokenType.INDEX, - "INET": TokenType.INET, - "INNER": TokenType.INNER, - "INSERT": TokenType.INSERT, - "INTERVAL": TokenType.INTERVAL, - "INTERSECT": TokenType.INTERSECT, - "INTO": TokenType.INTO, - "IS": TokenType.IS, - "ISNULL": TokenType.ISNULL, - "JOIN": TokenType.JOIN, - "KEEP": TokenType.KEEP, - "KILL": TokenType.KILL, - "LATERAL": TokenType.LATERAL, - "LEFT": TokenType.LEFT, - "LIKE": TokenType.LIKE, - "LIMIT": TokenType.LIMIT, - "LOAD": TokenType.LOAD, - "LOCK": TokenType.LOCK, - "MERGE": TokenType.MERGE, - "NAMESPACE": TokenType.NAMESPACE, - "NATURAL": TokenType.NATURAL, - "NEXT": TokenType.NEXT, - "NOT": TokenType.NOT, - "NOTNULL": TokenType.NOTNULL, - "NULL": TokenType.NULL, - "OBJECT": TokenType.OBJECT, - "OFFSET": TokenType.OFFSET, - "ON": TokenType.ON, - "OR": TokenType.OR, - "XOR": TokenType.XOR, - "ORDER BY": TokenType.ORDER_BY, - "ORDINALITY": TokenType.ORDINALITY, - "OUTER": TokenType.OUTER, - "OVER": TokenType.OVER, - "OVERLAPS": TokenType.OVERLAPS, - "OVERWRITE": TokenType.OVERWRITE, - "PARTITION": TokenType.PARTITION, - "PARTITION BY": TokenType.PARTITION_BY, - "PARTITIONED BY": TokenType.PARTITION_BY, - "PARTITIONED_BY": TokenType.PARTITION_BY, - "PERCENT": TokenType.PERCENT, - "PIVOT": TokenType.PIVOT, - "PRAGMA": TokenType.PRAGMA, - "PRIMARY KEY": TokenType.PRIMARY_KEY, - "PROCEDURE": TokenType.PROCEDURE, - "QUALIFY": TokenType.QUALIFY, - "RANGE": TokenType.RANGE, - "RECURSIVE": TokenType.RECURSIVE, - "REGEXP": TokenType.RLIKE, - "RENAME": TokenType.RENAME, - "REPLACE": TokenType.REPLACE, - "RETURNING": TokenType.RETURNING, - "REFERENCES": TokenType.REFERENCES, - "RIGHT": TokenType.RIGHT, - "RLIKE": TokenType.RLIKE, - "ROLLBACK": TokenType.ROLLBACK, - "ROLLUP": TokenType.ROLLUP, - "ROW": TokenType.ROW, - "ROWS": TokenType.ROWS, - "SCHEMA": TokenType.SCHEMA, - "SELECT": TokenType.SELECT, - "SEMI": TokenType.SEMI, - "SET": TokenType.SET, - "SETTINGS": TokenType.SETTINGS, - "SHOW": TokenType.SHOW, - "SIMILAR TO": TokenType.SIMILAR_TO, - "SOME": TokenType.SOME, - "SORT BY": TokenType.SORT_BY, - "START WITH": TokenType.START_WITH, - "STRAIGHT_JOIN": TokenType.STRAIGHT_JOIN, - "TABLE": TokenType.TABLE, - "TABLESAMPLE": TokenType.TABLE_SAMPLE, - "TEMP": TokenType.TEMPORARY, - "TEMPORARY": TokenType.TEMPORARY, - "THEN": TokenType.THEN, - "TRUE": TokenType.TRUE, - "TRUNCATE": TokenType.TRUNCATE, - "UNION": TokenType.UNION, - "UNKNOWN": TokenType.UNKNOWN, - "UNNEST": TokenType.UNNEST, - "UNPIVOT": TokenType.UNPIVOT, - "UPDATE": TokenType.UPDATE, - "USE": TokenType.USE, - "USING": TokenType.USING, - "UUID": TokenType.UUID, - "VALUES": TokenType.VALUES, - "VIEW": TokenType.VIEW, - "VOLATILE": TokenType.VOLATILE, - "WHEN": TokenType.WHEN, - "WHERE": TokenType.WHERE, - "WINDOW": TokenType.WINDOW, - "WITH": TokenType.WITH, - "APPLY": TokenType.APPLY, - "ARRAY": TokenType.ARRAY, - "BIT": TokenType.BIT, - "BOOL": TokenType.BOOLEAN, - "BOOLEAN": TokenType.BOOLEAN, - "BYTE": TokenType.TINYINT, - "MEDIUMINT": TokenType.MEDIUMINT, - "INT1": TokenType.TINYINT, - "TINYINT": TokenType.TINYINT, - "INT16": TokenType.SMALLINT, - "SHORT": TokenType.SMALLINT, - 
"SMALLINT": TokenType.SMALLINT, - "HUGEINT": TokenType.INT128, - "UHUGEINT": TokenType.UINT128, - "INT2": TokenType.SMALLINT, - "INTEGER": TokenType.INT, - "INT": TokenType.INT, - "INT4": TokenType.INT, - "INT32": TokenType.INT, - "INT64": TokenType.BIGINT, - "INT128": TokenType.INT128, - "INT256": TokenType.INT256, - "LONG": TokenType.BIGINT, - "BIGINT": TokenType.BIGINT, - "INT8": TokenType.TINYINT, - "UINT": TokenType.UINT, - "UINT128": TokenType.UINT128, - "UINT256": TokenType.UINT256, - "DEC": TokenType.DECIMAL, - "DECIMAL": TokenType.DECIMAL, - "DECIMAL32": TokenType.DECIMAL32, - "DECIMAL64": TokenType.DECIMAL64, - "DECIMAL128": TokenType.DECIMAL128, - "DECIMAL256": TokenType.DECIMAL256, - "BIGDECIMAL": TokenType.BIGDECIMAL, - "BIGNUMERIC": TokenType.BIGDECIMAL, - "LIST": TokenType.LIST, - "MAP": TokenType.MAP, - "NULLABLE": TokenType.NULLABLE, - "NUMBER": TokenType.DECIMAL, - "NUMERIC": TokenType.DECIMAL, - "FIXED": TokenType.DECIMAL, - "REAL": TokenType.FLOAT, - "FLOAT": TokenType.FLOAT, - "FLOAT4": TokenType.FLOAT, - "FLOAT8": TokenType.DOUBLE, - "DOUBLE": TokenType.DOUBLE, - "DOUBLE PRECISION": TokenType.DOUBLE, - "JSON": TokenType.JSON, - "JSONB": TokenType.JSONB, - "CHAR": TokenType.CHAR, - "CHARACTER": TokenType.CHAR, - "CHAR VARYING": TokenType.VARCHAR, - "CHARACTER VARYING": TokenType.VARCHAR, - "NCHAR": TokenType.NCHAR, - "VARCHAR": TokenType.VARCHAR, - "VARCHAR2": TokenType.VARCHAR, - "NVARCHAR": TokenType.NVARCHAR, - "NVARCHAR2": TokenType.NVARCHAR, - "BPCHAR": TokenType.BPCHAR, - "STR": TokenType.TEXT, - "STRING": TokenType.TEXT, - "TEXT": TokenType.TEXT, - "LONGTEXT": TokenType.LONGTEXT, - "MEDIUMTEXT": TokenType.MEDIUMTEXT, - "TINYTEXT": TokenType.TINYTEXT, - "CLOB": TokenType.TEXT, - "LONGVARCHAR": TokenType.TEXT, - "BINARY": TokenType.BINARY, - "BLOB": TokenType.VARBINARY, - "LONGBLOB": TokenType.LONGBLOB, - "MEDIUMBLOB": TokenType.MEDIUMBLOB, - "TINYBLOB": TokenType.TINYBLOB, - "BYTEA": TokenType.VARBINARY, - "VARBINARY": TokenType.VARBINARY, - "TIME": TokenType.TIME, - "TIMETZ": TokenType.TIMETZ, - "TIMESTAMP": TokenType.TIMESTAMP, - "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, - "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, - "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, - "TIMESTAMPNTZ": TokenType.TIMESTAMPNTZ, - "TIMESTAMP_NTZ": TokenType.TIMESTAMPNTZ, - "DATE": TokenType.DATE, - "DATETIME": TokenType.DATETIME, - "INT4RANGE": TokenType.INT4RANGE, - "INT4MULTIRANGE": TokenType.INT4MULTIRANGE, - "INT8RANGE": TokenType.INT8RANGE, - "INT8MULTIRANGE": TokenType.INT8MULTIRANGE, - "NUMRANGE": TokenType.NUMRANGE, - "NUMMULTIRANGE": TokenType.NUMMULTIRANGE, - "TSRANGE": TokenType.TSRANGE, - "TSMULTIRANGE": TokenType.TSMULTIRANGE, - "TSTZRANGE": TokenType.TSTZRANGE, - "TSTZMULTIRANGE": TokenType.TSTZMULTIRANGE, - "DATERANGE": TokenType.DATERANGE, - "DATEMULTIRANGE": TokenType.DATEMULTIRANGE, - "UNIQUE": TokenType.UNIQUE, - "VECTOR": TokenType.VECTOR, - "STRUCT": TokenType.STRUCT, - "SEQUENCE": TokenType.SEQUENCE, - "VARIANT": TokenType.VARIANT, - "ALTER": TokenType.ALTER, - "ANALYZE": TokenType.ANALYZE, - "CALL": TokenType.COMMAND, - "COMMENT": TokenType.COMMENT, - "EXPLAIN": TokenType.COMMAND, - "GRANT": TokenType.GRANT, - "OPTIMIZE": TokenType.COMMAND, - "PREPARE": TokenType.COMMAND, - "VACUUM": TokenType.COMMAND, - "USER-DEFINED": TokenType.USERDEFINED, - "FOR VERSION": TokenType.VERSION_SNAPSHOT, - "FOR TIMESTAMP": TokenType.TIMESTAMP_SNAPSHOT, - } - - WHITE_SPACE: t.Dict[t.Optional[str], TokenType] = { - " ": TokenType.SPACE, - "\t": TokenType.SPACE, - "\n": TokenType.BREAK, - "\r": 
TokenType.BREAK, - } - - COMMANDS = { - TokenType.COMMAND, - TokenType.EXECUTE, - TokenType.FETCH, - TokenType.SHOW, - TokenType.RENAME, - } - - COMMAND_PREFIX_TOKENS = {TokenType.SEMICOLON, TokenType.BEGIN} - - # Handle numeric literals like in hive (3L = BIGINT) - NUMERIC_LITERALS: t.Dict[str, str] = {} - - COMMENTS = ["--", ("/*", "*/")] - - __slots__ = ( - "sql", - "size", - "tokens", - "dialect", - "use_rs_tokenizer", - "_start", - "_current", - "_line", - "_col", - "_comments", - "_char", - "_end", - "_peek", - "_prev_token_line", - "_rs_dialect_settings", - ) - - def __init__( - self, dialect: DialectType = None, use_rs_tokenizer: t.Optional[bool] = None - ) -> None: - from sqlglot.dialects import Dialect - - self.dialect = Dialect.get_or_raise(dialect) - - # initialize `use_rs_tokenizer`, and allow it to be overwritten per Tokenizer instance - self.use_rs_tokenizer = ( - use_rs_tokenizer if use_rs_tokenizer is not None else USE_RS_TOKENIZER - ) - - if self.use_rs_tokenizer: - self._rs_dialect_settings = RsTokenizerDialectSettings( - unescaped_sequences=self.dialect.UNESCAPED_SEQUENCES, - identifiers_can_start_with_digit=self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT, - numbers_can_be_underscore_separated=self.dialect.NUMBERS_CAN_BE_UNDERSCORE_SEPARATED, - ) - - self.reset() - - def reset(self) -> None: - self.sql = "" - self.size = 0 - self.tokens: t.List[Token] = [] - self._start = 0 - self._current = 0 - self._line = 1 - self._col = 0 - self._comments: t.List[str] = [] - - self._char = "" - self._end = False - self._peek = "" - self._prev_token_line = -1 - - def tokenize(self, sql: str) -> t.List[Token]: - """Returns a list of tokens corresponding to the SQL string `sql`.""" - if self.use_rs_tokenizer: - return self.tokenize_rs(sql) - - self.reset() - self.sql = sql - self.size = len(sql) - - try: - self._scan() - except Exception as e: - start = max(self._current - 50, 0) - end = min(self._current + 50, self.size - 1) - context = self.sql[start:end] - raise TokenError(f"Error tokenizing '{context}'") from e - - return self.tokens - - def _scan(self, until: t.Optional[t.Callable] = None) -> None: - while self.size and not self._end: - current = self._current - - # Skip spaces here rather than iteratively calling advance() for performance reasons - while current < self.size: - char = self.sql[current] - - if char.isspace() and (char == " " or char == "\t"): - current += 1 - else: - break - - offset = current - self._current if current > self._current else 1 - - self._start = current - self._advance(offset) - - if not self._char.isspace(): - if self._char.isdigit(): - self._scan_number() - elif self._char in self._IDENTIFIERS: - self._scan_identifier(self._IDENTIFIERS[self._char]) - else: - self._scan_keywords() - - if until and until(): - break - - if self.tokens and self._comments: - self.tokens[-1].comments.extend(self._comments) - - def _chars(self, size: int) -> str: - if size == 1: - return self._char - - start = self._current - 1 - end = start + size - - return self.sql[start:end] if end <= self.size else "" - - def _advance(self, i: int = 1, alnum: bool = False) -> None: - if self.WHITE_SPACE.get(self._char) is TokenType.BREAK: - # Ensures we don't count an extra line if we get a \r\n line break sequence - if not (self._char == "\r" and self._peek == "\n"): - self._col = i - self._line += 1 - else: - self._col += i - - self._current += i - self._end = self._current >= self.size - self._char = self.sql[self._current - 1] - self._peek = "" if self._end else 
self.sql[self._current] - - if alnum and self._char.isalnum(): - # Here we use local variables instead of attributes for better performance - _col = self._col - _current = self._current - _end = self._end - _peek = self._peek - - while _peek.isalnum(): - _col += 1 - _current += 1 - _end = _current >= self.size - _peek = "" if _end else self.sql[_current] - - self._col = _col - self._current = _current - self._end = _end - self._peek = _peek - self._char = self.sql[_current - 1] - - @property - def _text(self) -> str: - return self.sql[self._start : self._current] - - def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: - self._prev_token_line = self._line - - if self._comments and token_type == TokenType.SEMICOLON and self.tokens: - self.tokens[-1].comments.extend(self._comments) - self._comments = [] - - self.tokens.append( - Token( - token_type, - text=self._text if text is None else text, - line=self._line, - col=self._col, - start=self._start, - end=self._current - 1, - comments=self._comments, - ) - ) - self._comments = [] - - # If we have either a semicolon or a begin token before the command's token, we'll parse - # whatever follows the command's token as a string - if ( - token_type in self.COMMANDS - and self._peek != ";" - and (len(self.tokens) == 1 or self.tokens[-2].token_type in self.COMMAND_PREFIX_TOKENS) - ): - start = self._current - tokens = len(self.tokens) - self._scan(lambda: self._peek == ";") - self.tokens = self.tokens[:tokens] - text = self.sql[start : self._current].strip() - if text: - self._add(TokenType.STRING, text) - - def _scan_keywords(self) -> None: - size = 0 - word = None - chars = self._text - char = chars - prev_space = False - skip = False - trie = self._KEYWORD_TRIE - single_token = char in self.SINGLE_TOKENS - - while chars: - if skip: - result = TrieResult.PREFIX - else: - result, trie = in_trie(trie, char.upper()) - - if result == TrieResult.FAILED: - break - if result == TrieResult.EXISTS: - word = chars - - end = self._current + size - size += 1 - - if end < self.size: - char = self.sql[end] - single_token = single_token or char in self.SINGLE_TOKENS - is_space = char.isspace() - - if not is_space or not prev_space: - if is_space: - char = " " - chars += char - prev_space = is_space - skip = False - else: - skip = True - else: - char = "" - break - - if word: - if self._scan_string(word): - return - if self._scan_comment(word): - return - if prev_space or single_token or not char: - self._advance(size - 1) - word = word.upper() - self._add(self.KEYWORDS[word], text=word) - return - - if self._char in self.SINGLE_TOKENS: - self._add(self.SINGLE_TOKENS[self._char], text=self._char) - return - - self._scan_var() - - def _scan_comment(self, comment_start: str) -> bool: - if comment_start not in self._COMMENTS: - return False - - comment_start_line = self._line - comment_start_size = len(comment_start) - comment_end = self._COMMENTS[comment_start] - - if comment_end: - # Skip the comment's start delimiter - self._advance(comment_start_size) - - comment_count = 1 - comment_end_size = len(comment_end) - - while not self._end: - if self._chars(comment_end_size) == comment_end: - comment_count -= 1 - if not comment_count: - break - - self._advance(alnum=True) - - # Nested comments are allowed by some dialects, e.g. 
databricks, duckdb, postgres - if ( - self.NESTED_COMMENTS - and not self._end - and self._chars(comment_end_size) == comment_start - ): - self._advance(comment_start_size) - comment_count += 1 - - self._comments.append(self._text[comment_start_size : -comment_end_size + 1]) - self._advance(comment_end_size - 1) - else: - while not self._end and self.WHITE_SPACE.get(self._peek) is not TokenType.BREAK: - self._advance(alnum=True) - self._comments.append(self._text[comment_start_size:]) - - if ( - comment_start == self.HINT_START - and self.tokens - and self.tokens[-1].token_type in self.TOKENS_PRECEDING_HINT - ): - self._add(TokenType.HINT) - - # Leading comment is attached to the succeeding token, whilst trailing comment to the preceding. - # Multiple consecutive comments are preserved by appending them to the current comments list. - if comment_start_line == self._prev_token_line: - self.tokens[-1].comments.extend(self._comments) - self._comments = [] - self._prev_token_line = self._line - - return True - - def _scan_number(self) -> None: - if self._char == "0": - peek = self._peek.upper() - if peek == "B": - return self._scan_bits() if self.BIT_STRINGS else self._add(TokenType.NUMBER) - elif peek == "X": - return self._scan_hex() if self.HEX_STRINGS else self._add(TokenType.NUMBER) - - decimal = False - scientific = 0 - - while True: - if self._peek.isdigit(): - self._advance() - elif self._peek == "." and not decimal: - if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER: - return self._add(TokenType.NUMBER) - decimal = True - self._advance() - elif self._peek in ("-", "+") and scientific == 1: - scientific += 1 - self._advance() - elif self._peek.upper() == "E" and not scientific: - scientific += 1 - self._advance() - elif self._peek.isidentifier(): - number_text = self._text - literal = "" - - while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: - literal += self._peek - self._advance() - - token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal.upper(), "")) - - if token_type: - self._add(TokenType.NUMBER, number_text) - self._add(TokenType.DCOLON, "::") - return self._add(token_type, literal) - else: - replaced = literal.replace("_", "") - if self.dialect.NUMBERS_CAN_BE_UNDERSCORE_SEPARATED and replaced.isdigit(): - return self._add(TokenType.NUMBER, number_text + replaced) - if self.dialect.IDENTIFIERS_CAN_START_WITH_DIGIT: - return self._add(TokenType.VAR) - - self._advance(-len(literal)) - return self._add(TokenType.NUMBER, number_text) - else: - return self._add(TokenType.NUMBER) - - def _scan_bits(self) -> None: - self._advance() - value = self._extract_value() - try: - # If `value` can't be converted to a binary, fallback to tokenizing it as an identifier - int(value, 2) - self._add(TokenType.BIT_STRING, value[2:]) # Drop the 0b - except ValueError: - self._add(TokenType.IDENTIFIER) - - def _scan_hex(self) -> None: - self._advance() - value = self._extract_value() - try: - # If `value` can't be converted to a hex, fallback to tokenizing it as an identifier - int(value, 16) - self._add(TokenType.HEX_STRING, value[2:]) # Drop the 0x - except ValueError: - self._add(TokenType.IDENTIFIER) - - def _extract_value(self) -> str: - while True: - char = self._peek.strip() - if char and char not in self.SINGLE_TOKENS: - self._advance(alnum=True) - else: - break - - return self._text - - def _scan_string(self, start: str) -> bool: - base = None - token_type = TokenType.STRING - - if start in self._QUOTES: - end = self._QUOTES[start] - elif start in 
self._FORMAT_STRINGS: - end, token_type = self._FORMAT_STRINGS[start] - - if token_type == TokenType.HEX_STRING: - base = 16 - elif token_type == TokenType.BIT_STRING: - base = 2 - elif token_type == TokenType.HEREDOC_STRING: - self._advance() - - if self._char == end: - tag = "" - else: - tag = self._extract_string( - end, - raw_string=True, - raise_unmatched=not self.HEREDOC_TAG_IS_IDENTIFIER, - ) - - if tag and self.HEREDOC_TAG_IS_IDENTIFIER and (self._end or not tag.isidentifier()): - if not self._end: - self._advance(-1) - - self._advance(-len(tag)) - self._add(self.HEREDOC_STRING_ALTERNATIVE) - return True - - end = f"{start}{tag}{end}" - else: - return False - - self._advance(len(start)) - text = self._extract_string(end, raw_string=token_type == TokenType.RAW_STRING) - - if base: - try: - int(text, base) - except Exception: - raise TokenError( - f"Numeric string contains invalid characters from {self._line}:{self._start}" - ) - - self._add(token_type, text) - return True - - def _scan_identifier(self, identifier_end: str) -> None: - self._advance() - text = self._extract_string( - identifier_end, escapes=self._IDENTIFIER_ESCAPES | {identifier_end} - ) - self._add(TokenType.IDENTIFIER, text) - - def _scan_var(self) -> None: - while True: - char = self._peek.strip() - if char and (char in self.VAR_SINGLE_TOKENS or char not in self.SINGLE_TOKENS): - self._advance(alnum=True) - else: - break - - self._add( - TokenType.VAR - if self.tokens and self.tokens[-1].token_type == TokenType.PARAMETER - else self.KEYWORDS.get(self._text.upper(), TokenType.VAR) - ) - - def _extract_string( - self, - delimiter: str, - escapes: t.Optional[t.Set[str]] = None, - raw_string: bool = False, - raise_unmatched: bool = True, - ) -> str: - text = "" - delim_size = len(delimiter) - escapes = self._STRING_ESCAPES if escapes is None else escapes - - while True: - if ( - not raw_string - and self.dialect.UNESCAPED_SEQUENCES - and self._peek - and self._char in self.STRING_ESCAPES - ): - unescaped_sequence = self.dialect.UNESCAPED_SEQUENCES.get(self._char + self._peek) - if unescaped_sequence: - self._advance(2) - text += unescaped_sequence - continue - if ( - (self.STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS or not raw_string) - and self._char in escapes - and (self._peek == delimiter or self._peek in escapes) - and (self._char not in self._QUOTES or self._char == self._peek) - ): - if self._peek == delimiter: - text += self._peek - else: - text += self._char + self._peek - - if self._current + 1 < self.size: - self._advance(2) - else: - raise TokenError(f"Missing {delimiter} from {self._line}:{self._current}") - else: - if self._chars(delim_size) == delimiter: - if delim_size > 1: - self._advance(delim_size - 1) - break - - if self._end: - if not raise_unmatched: - return text + self._char - - raise TokenError(f"Missing {delimiter} from {self._line}:{self._start}") - - current = self._current - 1 - self._advance(alnum=True) - text += self.sql[current : self._current - 1] - - return text - - def tokenize_rs(self, sql: str) -> t.List[Token]: - if not self._RS_TOKENIZER: - raise SqlglotError("Rust tokenizer is not available") - - tokens, error_msg = self._RS_TOKENIZER.tokenize(sql, self._rs_dialect_settings) - for token in tokens: - token.token_type = _ALL_TOKEN_TYPES[token.token_type_index] - - # Setting this here so partial token lists can be inspected even if there is a failure - self.tokens = tokens - - if error_msg is not None: - raise TokenError(error_msg) - - return tokens diff --git 
a/altimate_packages/sqlglot/transforms.py b/altimate_packages/sqlglot/transforms.py deleted file mode 100644 index 815139d2c..000000000 --- a/altimate_packages/sqlglot/transforms.py +++ /dev/null @@ -1,1020 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlglot import expressions as exp -from sqlglot.errors import UnsupportedError -from sqlglot.helper import find_new_name, name_sequence - - -if t.TYPE_CHECKING: - from sqlglot._typing import E - from sqlglot.generator import Generator - - -def preprocess( - transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], -) -> t.Callable[[Generator, exp.Expression], str]: - """ - Creates a new transform by chaining a sequence of transformations and converts the resulting - expression to SQL, using either the "_sql" method corresponding to the resulting expression, - or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). - - Args: - transforms: sequence of transform functions. These will be called in order. - - Returns: - Function that can be used as a generator transform. - """ - - def _to_sql(self, expression: exp.Expression) -> str: - expression_type = type(expression) - - try: - expression = transforms[0](expression) - for transform in transforms[1:]: - expression = transform(expression) - except UnsupportedError as unsupported_error: - self.unsupported(str(unsupported_error)) - - _sql_handler = getattr(self, expression.key + "_sql", None) - if _sql_handler: - return _sql_handler(expression) - - transforms_handler = self.TRANSFORMS.get(type(expression)) - if transforms_handler: - if expression_type is type(expression): - if isinstance(expression, exp.Func): - return self.function_fallback_sql(expression) - - # Ensures we don't enter an infinite loop. This can happen when the original expression - # has the same type as the final expression and there's no _sql method available for it, - # because then it'd re-enter _to_sql. - raise ValueError( - f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 
- ) - - return transforms_handler(self, expression) - - raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") - - return _to_sql - - -def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Select): - count = 0 - recursive_ctes = [] - - for unnest in expression.find_all(exp.Unnest): - if ( - not isinstance(unnest.parent, (exp.From, exp.Join)) - or len(unnest.expressions) != 1 - or not isinstance(unnest.expressions[0], exp.GenerateDateArray) - ): - continue - - generate_date_array = unnest.expressions[0] - start = generate_date_array.args.get("start") - end = generate_date_array.args.get("end") - step = generate_date_array.args.get("step") - - if not start or not end or not isinstance(step, exp.Interval): - continue - - alias = unnest.args.get("alias") - column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value" - - start = exp.cast(start, "date") - date_add = exp.func( - "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit") - ) - cast_date_add = exp.cast(date_add, "date") - - cte_name = "_generated_dates" + (f"_{count}" if count else "") - - base_query = exp.select(start.as_(column_name)) - recursive_query = ( - exp.select(cast_date_add) - .from_(cte_name) - .where(cast_date_add <= exp.cast(end, "date")) - ) - cte_query = base_query.union(recursive_query, distinct=False) - - generate_dates_query = exp.select(column_name).from_(cte_name) - unnest.replace(generate_dates_query.subquery(cte_name)) - - recursive_ctes.append( - exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name]) - ) - count += 1 - - if recursive_ctes: - with_expression = expression.args.get("with") or exp.With() - with_expression.set("recursive", True) - with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions]) - expression.set("with", with_expression) - - return expression - - -def unnest_generate_series(expression: exp.Expression) -> exp.Expression: - """Unnests GENERATE_SERIES or SEQUENCE table references.""" - this = expression.this - if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries): - unnest = exp.Unnest(expressions=[this]) - if expression.alias: - return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False) - - return unnest - - return expression - - -def unalias_group(expression: exp.Expression) -> exp.Expression: - """ - Replace references to select aliases in GROUP BY clauses. - - Example: - >>> import sqlglot - >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() - 'SELECT a AS b FROM x GROUP BY 1' - - Args: - expression: the expression that will be transformed. - - Returns: - The transformed expression. - """ - if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): - aliased_selects = { - e.alias: i - for i, e in enumerate(expression.parent.expressions, start=1) - if isinstance(e, exp.Alias) - } - - for group_by in expression.expressions: - if ( - isinstance(group_by, exp.Column) - and not group_by.table - and group_by.name in aliased_selects - ): - group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) - - return expression - - -def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: - """ - Convert SELECT DISTINCT ON statements to a subquery with a window function. - - This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 
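For context on the removed `eliminate_distinct_on` transform above: like the other helpers in this module, it is applied through sqlglot's `Expression.transform` API. A minimal sketch, assuming the upstream `sqlglot` package is installed (the exact generated SQL can vary by version):

```python
import sqlglot
from sqlglot.transforms import eliminate_distinct_on

# DISTINCT ON is Postgres-specific; the transform rewrites it into a
# ROW_NUMBER() window inside a subquery, filtered to the first row per key.
sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, c"
expression = sqlglot.parse_one(sql, read="postgres")
print(expression.transform(eliminate_distinct_on).sql())
# e.g. SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a, c) AS _row_number FROM t) AS _t WHERE _row_number = 1
```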
- - Args: - expression: the expression that will be transformed. - - Returns: - The transformed expression. - """ - if ( - isinstance(expression, exp.Select) - and expression.args.get("distinct") - and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple) - ): - row_number_window_alias = find_new_name(expression.named_selects, "_row_number") - - distinct_cols = expression.args["distinct"].pop().args["on"].expressions - window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) - - order = expression.args.get("order") - if order: - window.set("order", order.pop()) - else: - window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) - - window = exp.alias_(window, row_number_window_alias) - expression.select(window, copy=False) - - # We add aliases to the projections so that we can safely reference them in the outer query - new_selects = [] - taken_names = {row_number_window_alias} - for select in expression.selects[:-1]: - if select.is_star: - new_selects = [exp.Star()] - break - - if not isinstance(select, exp.Alias): - alias = find_new_name(taken_names, select.output_name or "_col") - quoted = select.this.args.get("quoted") if isinstance(select, exp.Column) else None - select = select.replace(exp.alias_(select, alias, quoted=quoted)) - - taken_names.add(select.output_name) - new_selects.append(select.args["alias"]) - - return ( - exp.select(*new_selects, copy=False) - .from_(expression.subquery("_t", copy=False), copy=False) - .where(exp.column(row_number_window_alias).eq(1), copy=False) - ) - - return expression - - -def eliminate_qualify(expression: exp.Expression) -> exp.Expression: - """ - Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. - - The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: - https://docs.snowflake.com/en/sql-reference/constructs/qualify - - Some dialects don't support window functions in the WHERE clause, so we need to include them as - projections in the subquery, in order to refer to them in the outer filter using aliases. Also, - if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, - otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a - newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the - corresponding expression to avoid creating invalid column references. 
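Similarly, the removed `eliminate_qualify` transform is what lets queries written with `QUALIFY` be generated for dialects that lack the clause. A rough sketch, again assuming the upstream `sqlglot` package; the generated aliases (`_t`, `_w`, ...) come from the transform itself:

```python
import sqlglot
from sqlglot.transforms import eliminate_qualify

sql = """
SELECT id, ts
FROM events
QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY ts) = 1
"""
expression = sqlglot.parse_one(sql, read="snowflake")
# The QUALIFY predicate becomes a WHERE clause on a wrapping subquery, with
# the window function projected out so it can be referenced by alias.
print(expression.transform(eliminate_qualify).sql())
```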
- """ - if isinstance(expression, exp.Select) and expression.args.get("qualify"): - taken = set(expression.named_selects) - for select in expression.selects: - if not select.alias_or_name: - alias = find_new_name(taken, "_c") - select.replace(exp.alias_(select, alias)) - taken.add(alias) - - def _select_alias_or_name(select: exp.Expression) -> str | exp.Column: - alias_or_name = select.alias_or_name - identifier = select.args.get("alias") or select.this - if isinstance(identifier, exp.Identifier): - return exp.column(alias_or_name, quoted=identifier.args.get("quoted")) - return alias_or_name - - outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects))) - qualify_filters = expression.args["qualify"].pop().this - expression_by_alias = { - select.alias: select.this - for select in expression.selects - if isinstance(select, exp.Alias) - } - - select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) - for select_candidate in list(qualify_filters.find_all(select_candidates)): - if isinstance(select_candidate, exp.Window): - if expression_by_alias: - for column in select_candidate.find_all(exp.Column): - expr = expression_by_alias.get(column.name) - if expr: - column.replace(expr) - - alias = find_new_name(expression.named_selects, "_w") - expression.select(exp.alias_(select_candidate, alias), copy=False) - column = exp.column(alias) - - if isinstance(select_candidate.parent, exp.Qualify): - qualify_filters = column - else: - select_candidate.replace(column) - elif select_candidate.name not in expression.named_selects: - expression.select(select_candidate.copy(), copy=False) - - return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where( - qualify_filters, copy=False - ) - - return expression - - -def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: - """ - Some dialects only allow the precision for parameterized types to be defined in the DDL and not in - other expressions. This transforms removes the precision from parameterized types in expressions. 
- """ - for node in expression.find_all(exp.DataType): - node.set( - "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] - ) - - return expression - - -def unqualify_unnest(expression: exp.Expression) -> exp.Expression: - """Remove references to unnest table aliases, added by the optimizer's qualify_columns step.""" - from sqlglot.optimizer.scope import find_all_in_scope - - if isinstance(expression, exp.Select): - unnest_aliases = { - unnest.alias - for unnest in find_all_in_scope(expression, exp.Unnest) - if isinstance(unnest.parent, (exp.From, exp.Join)) - } - if unnest_aliases: - for column in expression.find_all(exp.Column): - leftmost_part = column.parts[0] - if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases: - leftmost_part.pop() - - return expression - - -def unnest_to_explode( - expression: exp.Expression, - unnest_using_arrays_zip: bool = True, -) -> exp.Expression: - """Convert cross join unnest into lateral view explode.""" - - def _unnest_zip_exprs( - u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool - ) -> t.List[exp.Expression]: - if has_multi_expr: - if not unnest_using_arrays_zip: - raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays") - - # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions - zip_exprs: t.List[exp.Expression] = [ - exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs) - ] - u.set("expressions", zip_exprs) - return zip_exprs - return unnest_exprs - - def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]: - if u.args.get("offset"): - return exp.Posexplode - return exp.Inline if has_multi_expr else exp.Explode - - if isinstance(expression, exp.Select): - from_ = expression.args.get("from") - - if from_ and isinstance(from_.this, exp.Unnest): - unnest = from_.this - alias = unnest.args.get("alias") - exprs = unnest.expressions - has_multi_expr = len(exprs) > 1 - this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr) - - unnest.replace( - exp.Table( - this=_udtf_type(unnest, has_multi_expr)( - this=this, - expressions=expressions, - ), - alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None, - ) - ) - - joins = expression.args.get("joins") or [] - for join in list(joins): - join_expr = join.this - - is_lateral = isinstance(join_expr, exp.Lateral) - - unnest = join_expr.this if is_lateral else join_expr - - if isinstance(unnest, exp.Unnest): - if is_lateral: - alias = join_expr.args.get("alias") - else: - alias = unnest.args.get("alias") - exprs = unnest.expressions - # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here - has_multi_expr = len(exprs) > 1 - exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr) - - joins.remove(join) - - alias_cols = alias.columns if alias else [] - - # # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases - # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount. 
- # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html - - if not has_multi_expr and len(alias_cols) not in (1, 2): - raise UnsupportedError( - "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases" - ) - - for e, column in zip(exprs, alias_cols): - expression.append( - "laterals", - exp.Lateral( - this=_udtf_type(unnest, has_multi_expr)(this=e), - view=True, - alias=exp.TableAlias( - this=alias.this, # type: ignore - columns=alias_cols, - ), - ), - ) - - return expression - - -def explode_projection_to_unnest( - index_offset: int = 0, -) -> t.Callable[[exp.Expression], exp.Expression]: - """Convert explode/posexplode projections into unnests.""" - - def _explode_projection_to_unnest(expression: exp.Expression) -> exp.Expression: - if isinstance(expression, exp.Select): - from sqlglot.optimizer.scope import Scope - - taken_select_names = set(expression.named_selects) - taken_source_names = {name for name, _ in Scope(expression).references} - - def new_name(names: t.Set[str], name: str) -> str: - name = find_new_name(names, name) - names.add(name) - return name - - arrays: t.List[exp.Condition] = [] - series_alias = new_name(taken_select_names, "pos") - series = exp.alias_( - exp.Unnest( - expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] - ), - new_name(taken_source_names, "_u"), - table=[series_alias], - ) - - # we use list here because expression.selects is mutated inside the loop - for select in list(expression.selects): - explode = select.find(exp.Explode) - - if explode: - pos_alias = "" - explode_alias = "" - - if isinstance(select, exp.Alias): - explode_alias = select.args["alias"] - alias = select - elif isinstance(select, exp.Aliases): - pos_alias = select.aliases[0] - explode_alias = select.aliases[1] - alias = select.replace(exp.alias_(select.this, "", copy=False)) - else: - alias = select.replace(exp.alias_(select, "")) - explode = alias.find(exp.Explode) - assert explode - - is_posexplode = isinstance(explode, exp.Posexplode) - explode_arg = explode.this - - if isinstance(explode, exp.ExplodeOuter): - bracket = explode_arg[0] - bracket.set("safe", True) - bracket.set("offset", True) - explode_arg = exp.func( - "IF", - exp.func( - "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array()) - ).eq(0), - exp.array(bracket, copy=False), - explode_arg, - ) - - # This ensures that we won't use [POS]EXPLODE's argument as a new selection - if isinstance(explode_arg, exp.Column): - taken_select_names.add(explode_arg.output_name) - - unnest_source_alias = new_name(taken_source_names, "_u") - - if not explode_alias: - explode_alias = new_name(taken_select_names, "col") - - if is_posexplode: - pos_alias = new_name(taken_select_names, "pos") - - if not pos_alias: - pos_alias = new_name(taken_select_names, "pos") - - alias.set("alias", exp.to_identifier(explode_alias)) - - series_table_alias = series.args["alias"].this - column = exp.If( - this=exp.column(series_alias, table=series_table_alias).eq( - exp.column(pos_alias, table=unnest_source_alias) - ), - true=exp.column(explode_alias, table=unnest_source_alias), - ) - - explode.replace(column) - - if is_posexplode: - expressions = expression.expressions - expressions.insert( - expressions.index(alias) + 1, - exp.If( - this=exp.column(series_alias, table=series_table_alias).eq( - exp.column(pos_alias, table=unnest_source_alias) - ), - true=exp.column(pos_alias, table=unnest_source_alias), - ).as_(pos_alias), - ) - expression.set("expressions", 
expressions) - - if not arrays: - if expression.args.get("from"): - expression.join(series, copy=False, join_type="CROSS") - else: - expression.from_(series, copy=False) - - size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) - arrays.append(size) - - # trino doesn't support left join unnest with on conditions - # if it did, this would be much simpler - expression.join( - exp.alias_( - exp.Unnest( - expressions=[explode_arg.copy()], - offset=exp.to_identifier(pos_alias), - ), - unnest_source_alias, - table=[explode_alias], - ), - join_type="CROSS", - copy=False, - ) - - if index_offset != 1: - size = size - 1 - - expression.where( - exp.column(series_alias, table=series_table_alias) - .eq(exp.column(pos_alias, table=unnest_source_alias)) - .or_( - (exp.column(series_alias, table=series_table_alias) > size).and_( - exp.column(pos_alias, table=unnest_source_alias).eq(size) - ) - ), - copy=False, - ) - - if arrays: - end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) - - if index_offset != 1: - end = end - (1 - index_offset) - series.expressions[0].set("end", end) - - return expression - - return _explode_projection_to_unnest - - -def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: - """Transforms percentiles by adding a WITHIN GROUP clause to them.""" - if ( - isinstance(expression, exp.PERCENTILES) - and not isinstance(expression.parent, exp.WithinGroup) - and expression.expression - ): - column = expression.this.pop() - expression.set("this", expression.expression.pop()) - order = exp.Order(expressions=[exp.Ordered(this=column)]) - expression = exp.WithinGroup(this=expression, expression=order) - - return expression - - -def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: - """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.""" - if ( - isinstance(expression, exp.WithinGroup) - and isinstance(expression.this, exp.PERCENTILES) - and isinstance(expression.expression, exp.Order) - ): - quantile = expression.this.this - input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this - return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) - - return expression - - -def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: - """Uses projection output names in recursive CTE definitions to define the CTEs' columns.""" - if isinstance(expression, exp.With) and expression.recursive: - next_name = name_sequence("_c_") - - for cte in expression.expressions: - if not cte.args["alias"].columns: - query = cte.this - if isinstance(query, exp.SetOperation): - query = query.this - - cte.args["alias"].set( - "columns", - [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], - ) - - return expression - - -def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: - """Replace 'epoch' in casts by the equivalent date literal.""" - if ( - isinstance(expression, (exp.Cast, exp.TryCast)) - and expression.name.lower() == "epoch" - and expression.to.this in exp.DataType.TEMPORAL_TYPES - ): - expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) - - return expression - - -def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: - """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" - if isinstance(expression, exp.Select): - for join in expression.args.get("joins") or []: - on = join.args.get("on") - if on and join.kind in ("SEMI", 
"ANTI"): - subquery = exp.select("1").from_(join.this).where(on) - exists = exp.Exists(this=subquery) - if join.kind == "ANTI": - exists = exists.not_(copy=False) - - join.pop() - expression.where(exists, copy=False) - - return expression - - -def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: - """ - Converts a query with a FULL OUTER join to a union of identical queries that - use LEFT/RIGHT OUTER joins instead. This transformation currently only works - for queries that have a single FULL OUTER join. - """ - if isinstance(expression, exp.Select): - full_outer_joins = [ - (index, join) - for index, join in enumerate(expression.args.get("joins") or []) - if join.side == "FULL" - ] - - if len(full_outer_joins) == 1: - expression_copy = expression.copy() - expression.set("limit", None) - index, full_outer_join = full_outer_joins[0] - - tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name) - join_conditions = full_outer_join.args.get("on") or exp.and_( - *[ - exp.column(col, tables[0]).eq(exp.column(col, tables[1])) - for col in full_outer_join.args.get("using") - ] - ) - - full_outer_join.set("side", "left") - anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions) - expression_copy.args["joins"][index].set("side", "right") - expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_()) - expression_copy.args.pop("with", None) # remove CTEs from RIGHT side - expression.args.pop("order", None) # remove order by from LEFT side - - return exp.union(expression, expression_copy, copy=False, distinct=False) - - return expression - - -def move_ctes_to_top_level(expression: E) -> E: - """ - Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be - defined at the top-level, so for example queries like: - - SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq - - are invalid in those dialects. This transformation can be used to ensure all CTEs are - moved to the top level so that the final SQL code is valid from a syntax standpoint. - - TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 
- """ - top_level_with = expression.args.get("with") - for inner_with in expression.find_all(exp.With): - if inner_with.parent is expression: - continue - - if not top_level_with: - top_level_with = inner_with.pop() - expression.set("with", top_level_with) - else: - if inner_with.recursive: - top_level_with.set("recursive", True) - - parent_cte = inner_with.find_ancestor(exp.CTE) - inner_with.pop() - - if parent_cte: - i = top_level_with.expressions.index(parent_cte) - top_level_with.expressions[i:i] = inner_with.expressions - top_level_with.set("expressions", top_level_with.expressions) - else: - top_level_with.set( - "expressions", top_level_with.expressions + inner_with.expressions - ) - - return expression - - -def ensure_bools(expression: exp.Expression) -> exp.Expression: - """Converts numeric values used in conditions into explicit boolean expressions.""" - from sqlglot.optimizer.canonicalize import ensure_bools - - def _ensure_bool(node: exp.Expression) -> None: - if ( - node.is_number - or ( - not isinstance(node, exp.SubqueryPredicate) - and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) - ) - or (isinstance(node, exp.Column) and not node.type) - ): - node.replace(node.neq(0)) - - for node in expression.walk(): - ensure_bools(node, _ensure_bool) - - return expression - - -def unqualify_columns(expression: exp.Expression) -> exp.Expression: - for column in expression.find_all(exp.Column): - # We only wanna pop off the table, db, catalog args - for part in column.parts[:-1]: - part.pop() - - return expression - - -def remove_unique_constraints(expression: exp.Expression) -> exp.Expression: - assert isinstance(expression, exp.Create) - for constraint in expression.find_all(exp.UniqueColumnConstraint): - if constraint.parent: - constraint.parent.pop() - - return expression - - -def ctas_with_tmp_tables_to_create_tmp_view( - expression: exp.Expression, - tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, -) -> exp.Expression: - assert isinstance(expression, exp.Create) - properties = expression.args.get("properties") - temporary = any( - isinstance(prop, exp.TemporaryProperty) - for prop in (properties.expressions if properties else []) - ) - - # CTAS with temp tables map to CREATE TEMPORARY VIEW - if expression.kind == "TABLE" and temporary: - if expression.expression: - return exp.Create( - kind="TEMPORARY VIEW", - this=expression.this, - expression=expression.expression, - ) - return tmp_storage_provider(expression) - - return expression - - -def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: - """ - In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the - PARTITIONED BY value is an array of column names, they are transformed into a schema. - The corresponding columns are removed from the create statement. 
- """ - assert isinstance(expression, exp.Create) - has_schema = isinstance(expression.this, exp.Schema) - is_partitionable = expression.kind in {"TABLE", "VIEW"} - - if has_schema and is_partitionable: - prop = expression.find(exp.PartitionedByProperty) - if prop and prop.this and not isinstance(prop.this, exp.Schema): - schema = expression.this - columns = {v.name.upper() for v in prop.this.expressions} - partitions = [col for col in schema.expressions if col.name.upper() in columns] - schema.set("expressions", [e for e in schema.expressions if e not in partitions]) - prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) - expression.set("this", schema) - - return expression - - -def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: - """ - Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. - - Currently, SQLGlot uses the DATASOURCE format for Spark 3. - """ - assert isinstance(expression, exp.Create) - prop = expression.find(exp.PartitionedByProperty) - if ( - prop - and prop.this - and isinstance(prop.this, exp.Schema) - and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) - ): - prop_this = exp.Tuple( - expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] - ) - schema = expression.this - for e in prop.this.expressions: - schema.append("expressions", e) - prop.set("this", prop_this) - - return expression - - -def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: - """Converts struct arguments to aliases, e.g. STRUCT(1 AS y).""" - if isinstance(expression, exp.Struct): - expression.set( - "expressions", - [ - exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e - for e in expression.expressions - ], - ) - - return expression - - -def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: - """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178 - - 1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax. - - 2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view. - - The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query. - - You cannot use the (+) operator to outer-join a table to itself, although self joins are valid. - - The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator. - - A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator. - - A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression. - - A WHERE condition cannot compare any column marked with the (+) operator with a subquery. 
- - -- example with WHERE - SELECT d.department_name, sum(e.salary) as total_salary - FROM departments d, employees e - WHERE e.department_id(+) = d.department_id - group by department_name - - -- example of left correlation in select - SELECT d.department_name, ( - SELECT SUM(e.salary) - FROM employees e - WHERE e.department_id(+) = d.department_id) AS total_salary - FROM departments d; - - -- example of left correlation in from - SELECT d.department_name, t.total_salary - FROM departments d, ( - SELECT SUM(e.salary) AS total_salary - FROM employees e - WHERE e.department_id(+) = d.department_id - ) t - """ - - from sqlglot.optimizer.scope import traverse_scope - from sqlglot.optimizer.normalize import normalize, normalized - from collections import defaultdict - - # we go in reverse to check the main query for left correlation - for scope in reversed(traverse_scope(expression)): - query = scope.expression - - where = query.args.get("where") - joins = query.args.get("joins", []) - - # knockout: we do not support left correlation (see point 2) - assert not scope.is_correlated_subquery, "Correlated queries are not supported" - - # nothing to do - we check it here after knockout above - if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)): - continue - - # make sure we have AND of ORs to have clear join terms - where = normalize(where.this) - assert normalized(where), "Cannot normalize JOIN predicates" - - joins_ons = defaultdict(list) # dict of {name: list of join AND conditions} - for cond in [where] if not isinstance(where, exp.And) else where.flatten(): - join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")] - - left_join_table = set(col.table for col in join_cols) - if not left_join_table: - continue - - assert not ( - len(left_join_table) > 1 - ), "Cannot combine JOIN predicates from different tables" - - for col in join_cols: - col.set("join_mark", False) - - joins_ons[left_join_table.pop()].append(cond) - - old_joins = {join.alias_or_name: join for join in joins} - new_joins = {} - query_from = query.args["from"] - - for table, predicates in joins_ons.items(): - join_what = old_joins.get(table, query_from).this.copy() - new_joins[join_what.alias_or_name] = exp.Join( - this=join_what, on=exp.and_(*predicates), kind="LEFT" - ) - - for p in predicates: - while isinstance(p.parent, exp.Paren): - p.parent.replace(p) - - parent = p.parent - p.pop() - if isinstance(parent, exp.Binary): - parent.replace(parent.right if parent.left is None else parent.left) - elif isinstance(parent, exp.Where): - parent.pop() - - if query_from.alias_or_name in new_joins: - only_old_joins = old_joins.keys() - new_joins.keys() - assert ( - len(only_old_joins) >= 1 - ), "Cannot determine which table to use in the new FROM clause" - - new_from_name = list(only_old_joins)[0] - query.set("from", exp.From(this=old_joins[new_from_name].this)) - - if new_joins: - for n, j in old_joins.items(): # preserve any other joins - if n not in new_joins and n != query.args["from"].name: - if not j.kind: - j.set("kind", "CROSS") - new_joins[n] = j - query.set("joins", list(new_joins.values())) - - return expression - - -def any_to_exists(expression: exp.Expression) -> exp.Expression: - """ - Transform ANY operator to Spark's EXISTS - - For example, - - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5) - - Both ANY and EXISTS accept queries but currently only array expressions are supported for this - 
transformation - """ - if isinstance(expression, exp.Select): - for any_expr in expression.find_all(exp.Any): - this = any_expr.this - if isinstance(this, exp.Query): - continue - - binop = any_expr.parent - if isinstance(binop, exp.Binary): - lambda_arg = exp.to_identifier("x") - any_expr.replace(lambda_arg) - lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg]) - binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr)) - - return expression - - -def eliminate_window_clause(expression: exp.Expression) -> exp.Expression: - """Eliminates the `WINDOW` query clause by inling each named window.""" - if isinstance(expression, exp.Select) and expression.args.get("windows"): - from sqlglot.optimizer.scope import find_all_in_scope - - windows = expression.args["windows"] - expression.set("windows", None) - - window_expression: t.Dict[str, exp.Expression] = {} - - def _inline_inherited_window(window: exp.Expression) -> None: - inherited_window = window_expression.get(window.alias.lower()) - if not inherited_window: - return - - window.set("alias", None) - for key in ("partition_by", "order", "spec"): - arg = inherited_window.args.get(key) - if arg: - window.set(key, arg.copy()) - - for window in windows: - _inline_inherited_window(window) - window_expression[window.name.lower()] = window - - for window in find_all_in_scope(expression, exp.Window): - _inline_inherited_window(window) - - return expression diff --git a/altimate_packages/sqlglot/trie.py b/altimate_packages/sqlglot/trie.py deleted file mode 100644 index 59601ee35..000000000 --- a/altimate_packages/sqlglot/trie.py +++ /dev/null @@ -1,81 +0,0 @@ -import typing as t -from enum import Enum, auto - -key = t.Sequence[t.Hashable] - - -class TrieResult(Enum): - FAILED = auto() - PREFIX = auto() - EXISTS = auto() - - -def new_trie(keywords: t.Iterable[key], trie: t.Optional[t.Dict] = None) -> t.Dict: - """ - Creates a new trie out of a collection of keywords. - - The trie is represented as a sequence of nested dictionaries keyed by either single - character strings, or by 0, which is used to designate that a keyword is in the trie. - - Example: - >>> new_trie(["bla", "foo", "blab"]) - {'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}} - - Args: - keywords: the keywords to create the trie from. - trie: a trie to mutate instead of creating a new one - - Returns: - The trie corresponding to `keywords`. - """ - trie = {} if trie is None else trie - - for key in keywords: - current = trie - for char in key: - current = current.setdefault(char, {}) - - current[0] = True - - return trie - - -def in_trie(trie: t.Dict, key: key) -> t.Tuple[TrieResult, t.Dict]: - """ - Checks whether a key is in a trie. - - Examples: - >>> in_trie(new_trie(["cat"]), "bob") - (<TrieResult.FAILED: 1>, {'c': {'a': {'t': {0: True}}}}) - - >>> in_trie(new_trie(["cat"]), "ca") - (<TrieResult.PREFIX: 2>, {'t': {0: True}}) - - >>> in_trie(new_trie(["cat"]), "cat") - (<TrieResult.EXISTS: 3>, {0: True}) - - Args: - trie: The trie to be searched. - key: The target key.
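Worth noting how the trie module being removed here is consumed by the tokenizer's `_scan_keywords` above: the keyword table is compiled into a trie so multi-word keywords such as `GROUP BY` can be matched one character at a time. A small sketch against the upstream `sqlglot` package:

```python
from sqlglot.trie import TrieResult, in_trie, new_trie

trie = new_trie(["GROUP BY", "GROUP"])

# The tokenizer keeps scanning while the lookup reports PREFIX and emits a
# keyword token once it reports EXISTS; FAILED falls back to other scanners.
assert in_trie(trie, "GROUP")[0] is TrieResult.EXISTS
assert in_trie(trie, "GROUP B")[0] is TrieResult.PREFIX
assert in_trie(trie, "HAVING")[0] is TrieResult.FAILED
```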
- - Returns: - A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point - where the search stops, and `value` is a TrieResult value that can be one of: - - - TrieResult.FAILED: the search was unsuccessful - - TrieResult.PREFIX: `value` is a prefix of a keyword in `trie` - - TrieResult.EXISTS: `key` exists in `trie` - """ - if not key: - return (TrieResult.FAILED, trie) - - current = trie - for char in key: - if char not in current: - return (TrieResult.FAILED, current) - current = current[char] - - if 0 in current: - return (TrieResult.EXISTS, current) - - return (TrieResult.PREFIX, current) diff --git a/dbt_cloud_integration.py b/dbt_cloud_integration.py deleted file mode 100644 index 12484ddc3..000000000 --- a/dbt_cloud_integration.py +++ /dev/null @@ -1,82 +0,0 @@ -from decimal import Decimal - -import os -import sys -import contextlib -from collections.abc import Iterable -from datetime import date, datetime, time -from typing import ( - Dict, - List, -) - - -@contextlib.contextmanager -def add_path(path): - sys.path.append(path) - try: - yield - finally: - sys.path.remove(path) - - -def validate_sql( - sql: str, - dialect: str, - models: List[Dict], -): - try: - ALTIMATE_PACKAGE_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "altimate_packages" - ) - with add_path(ALTIMATE_PACKAGE_PATH): - from altimate.validate_sql import validate_sql_from_models - - return validate_sql_from_models(sql, dialect, models) - except Exception as e: - raise Exception(str(e)) - - -def fetch_schema_from_sql(sql: str, dialect: str): - try: - ALTIMATE_PACKAGE_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "altimate_packages" - ) - with add_path(ALTIMATE_PACKAGE_PATH): - from altimate.fetch_schema import fetch_schema - - return fetch_schema(sql, dialect) - except Exception as e: - raise Exception(str(e)) - -def validate_whether_sql_has_columns(sql: str, dialect: str): - try: - ALTIMATE_PACKAGE_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "altimate_packages" - ) - with add_path(ALTIMATE_PACKAGE_PATH): - from altimate.fetch_schema import validate_whether_sql_has_columns - - return validate_whether_sql_has_columns(sql, dialect) - except Exception as e: - raise Exception(str(e)) - - -def to_dict(obj): - if isinstance(obj, str): - return obj - if isinstance(obj, Decimal): - return float(obj) - if isinstance(obj, (datetime, date, time)): - return obj.isoformat() - elif isinstance(obj, dict): - return dict((key, to_dict(val)) for key, val in obj.items()) - elif isinstance(obj, Iterable): - return [to_dict(val) for val in obj] - elif hasattr(obj, "__dict__"): - return to_dict(vars(obj)) - elif hasattr(obj, "__slots__"): - return to_dict( - dict((name, getattr(obj, name)) for name in getattr(obj, "__slots__")) - ) - return obj diff --git a/dbt_core_integration.py b/dbt_core_integration.py deleted file mode 100644 index 41dc97f7c..000000000 --- a/dbt_core_integration.py +++ /dev/null @@ -1,906 +0,0 @@ -try: - from dbt.version import __version__ as dbt_version -except Exception: - raise Exception("dbt not found. 
Please install dbt to use this extension.") - - -from decimal import Decimal -import os -import threading -import uuid -import sys -import contextlib -from collections import UserDict -from collections.abc import Iterable -from datetime import date, datetime, time -from copy import copy -from functools import lru_cache, partial -from hashlib import md5 -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, - Tuple, - TypeVar, - Union, -) - -import agate -import json -from dbt.adapters.factory import get_adapter, register_adapter -from dbt.config.runtime import RuntimeConfig -from dbt.flags import set_from_args -from dbt.parser.manifest import ManifestLoader, process_node -from dbt.parser.sql import SqlBlockParser, SqlMacroParser -from dbt.task.sql import SqlCompileRunner, SqlExecuteRunner -from dbt.tracking import disable_tracking - -DBT_MAJOR_VER, DBT_MINOR_VER, DBT_PATCH_VER = ( - int(v) if v.isnumeric() else v for v in dbt_version.split(".") -) - -if DBT_MAJOR_VER >=1 and DBT_MINOR_VER >= 8: - from dbt.contracts.graph.manifest import Manifest # type: ignore - from dbt.contracts.graph.nodes import ManifestNode, CompiledNode # type: ignore - from dbt.artifacts.resources.v1.components import ColumnInfo # type: ignore - from dbt.artifacts.resources.types import NodeType # type: ignore - from dbt_common.events.functions import fire_event # type: ignore - from dbt.artifacts.schemas.manifest import WritableManifest # type: ignore -elif DBT_MAJOR_VER >= 1 and DBT_MINOR_VER > 3: - from dbt.contracts.graph.nodes import ColumnInfo, ManifestNode, CompiledNode # type: ignore - from dbt.node_types import NodeType # type: ignore - from dbt.contracts.graph.manifest import WritableManifest # type: ignore - from dbt.events.functions import fire_event # type: ignore -else: - from dbt.contracts.graph.compiled import ManifestNode, CompiledNode # type: ignore - from dbt.contracts.graph.parsed import ColumnInfo # type: ignore - from dbt.node_types import NodeType # type: ignore - from dbt.events.functions import fire_event # type: ignore - - -if TYPE_CHECKING: - # These imports are only used for type checking - from dbt.adapters.base import BaseRelation # type: ignore - if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8: - from dbt.adapters.contracts.connection import AdapterResponse - else: - from dbt.contracts.connection import AdapterResponse - -Primitive = Union[bool, str, float, None] -PrimitiveDict = Dict[str, Primitive] - -CACHE = {} -CACHE_VERSION = 1 -SQL_CACHE_SIZE = 1024 - -MANIFEST_ARTIFACT = "manifest.json" - -RAW_CODE = "raw_code" if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 3 else "raw_sql" -COMPILED_CODE = ( - "compiled_code" if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 3 else "compiled_sql" -) - -JINJA_CONTROL_SEQS = ["{{", "}}", "{%", "%}", "{#", "#}"] - -T = TypeVar("T") -REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES = "REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES" -DBT_DEBUG = "DBT_DEBUG" -DBT_DEFER = "DBT_DEFER" -DBT_STATE = "DBT_STATE" -DBT_FAVOR_STATE = "DBT_FAVOR_STATE" - -@contextlib.contextmanager -def add_path(path): - sys.path.append(path) - try: - yield - finally: - sys.path.remove(path) - - -def validate_sql( - sql: str, - dialect: str, - models: List[Dict], -): - try: - ALTIMATE_PACKAGE_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "altimate_packages" - ) - with add_path(ALTIMATE_PACKAGE_PATH): - from altimate.validate_sql import validate_sql_from_models - - return validate_sql_from_models(sql, dialect, models) - except Exception as e: - raise 
Exception(str(e)) - -def fetch_schema_from_sql(sql: str, dialect: str): - try: - ALTIMATE_PACKAGE_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "altimate_packages" - ) - with add_path(ALTIMATE_PACKAGE_PATH): - from altimate.fetch_schema import fetch_schema - - return fetch_schema(sql, dialect) - except Exception as e: - raise Exception(str(e)) - -def validate_whether_sql_has_columns(sql: str, dialect: str): - try: - ALTIMATE_PACKAGE_PATH = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "altimate_packages" - ) - with add_path(ALTIMATE_PACKAGE_PATH): - from altimate.fetch_schema import validate_whether_sql_has_columns - - return validate_whether_sql_has_columns(sql, dialect) - except Exception as e: - raise Exception(str(e)) - - -def to_dict(obj): - if isinstance(obj, agate.Table): - return { - "rows": [to_dict(row) for row in obj.rows], - "column_names": obj.column_names, - "column_types": list(map(lambda x: x.__class__.__name__, obj.column_types)), - } - if isinstance(obj, str): - return obj - if isinstance(obj, Decimal): - return float(obj) - if isinstance(obj, (datetime, date, time)): - return obj.isoformat() - elif isinstance(obj, dict): - return dict((key, to_dict(val)) for key, val in obj.items()) - elif isinstance(obj, Iterable): - return [to_dict(val) for val in obj] - elif hasattr(obj, "__dict__"): - return to_dict(vars(obj)) - elif hasattr(obj, "__slots__"): - return to_dict( - dict((name, getattr(obj, name)) for name in getattr(obj, "__slots__")) - ) - return obj - - -def has_jinja(query: str) -> bool: - """Utility to check for jinja prior to certain compilation procedures""" - return any(seq in query for seq in JINJA_CONTROL_SEQS) - - -def memoize_get_rendered(function): - """Custom memoization function for dbt-core jinja interface""" - - def wrapper( - string: str, - ctx: Dict[str, Any], - node: "ManifestNode" = None, - capture_macros: bool = False, - native: bool = False, - ): - v = md5(string.strip().encode("utf-8")).hexdigest() - v += "__" + str(CACHE_VERSION) - if capture_macros == True and node is not None: - if node.is_ephemeral: - return function(string, ctx, node, capture_macros, native) - v += "__" + node.unique_id - rv = CACHE.get(v) - if rv is not None: - return rv - else: - rv = function(string, ctx, node, capture_macros, native) - CACHE[v] = rv - return rv - - return wrapper - - -def default_profiles_dir(project_dir): - """Determines the directory where dbt will look for profiles.yml. 
- - When DBT_PROFILES_DIR is set: - - If it's an absolute path, use it as is - - If it's a relative path, resolve it relative to the project directory - This matches dbt core's behavior and other path handling in the codebase - (see https://github.com/AltimateAI/vscode-dbt-power-user/issues/1518) - - When DBT_PROFILES_DIR is not set: - - Look for profiles.yml in the project directory - - If not found, default to ~/.dbt/ - """ - if "DBT_PROFILES_DIR" in os.environ: - profiles_dir = os.path.expanduser(os.environ["DBT_PROFILES_DIR"]) - if os.path.isabs(profiles_dir): - return os.path.normpath(profiles_dir) - return os.path.normpath(os.path.join(project_dir, profiles_dir)) - project_profiles_file = os.path.normpath(os.path.join(project_dir, "profiles.yml")) - return ( - project_dir - if os.path.exists(project_profiles_file) - else os.path.join(os.path.expanduser("~"), ".dbt") - ) - - -def target_path(project_dir): - if "DBT_TARGET_PATH" in os.environ: - target_path = os.path.expanduser(os.environ["DBT_TARGET_PATH"]) - if os.path.isabs(target_path): - return os.path.normpath(target_path) - return os.path.normpath(os.path.join(project_dir, target_path)) - return None - - -def find_package_paths(project_directories): - def get_package_path(project_dir): - try: - project = DbtProject( - project_dir=project_dir, - profiles_dir=default_profiles_dir(project_dir), - target_path=target_path(project_dir), - ) - project.init_config() - packages_path = project.config.packages_install_path - if os.path.isabs(packages_path): - return os.path.normpath(packages_path) - return os.path.normpath(os.path.join(project_dir, packages_path)) - except Exception as e: - # We don't care about exceptions here, that is dealt with later when the project is loaded - pass - - return list(map(get_package_path, project_directories)) - - -# Performance hacks -# jinja.get_rendered = memoize_get_rendered(jinja.get_rendered) -disable_tracking() -fire_event = lambda e: None - - -class ConfigInterface: - """This mimic dbt-core args based interface for dbt-core - class instantiation""" - - def __init__( - self, - threads: Optional[int] = 1, - target: Optional[str] = None, - profiles_dir: Optional[str] = None, - project_dir: Optional[str] = None, - profile: Optional[str] = None, - target_path: Optional[str] = None, - defer: Optional[bool] = False, - state: Optional[str] = None, - favor_state: Optional[bool] = False, - # dict in 1.5.x onwards, json string before. - vars: Optional[Union[Dict[str, Any], str]] = {} if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 5 else "{}", - ): - self.threads = threads - self.target = target if target else os.environ.get("DBT_TARGET") - self.profiles_dir = profiles_dir - self.project_dir = project_dir - self.dependencies = [] - self.single_threaded = threads == 1 - self.quiet = True - self.profile = profile if profile else os.environ.get("DBT_PROFILE") - self.target_path = target_path - self.defer = defer - self.state = state - self.favor_state = favor_state - # dict in 1.5.x onwards, json string before. 
- if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 5: - self.vars = vars if vars else json.loads(os.environ.get("DBT_VARS", "{}")) - else: - self.vars = vars if vars else os.environ.get("DBT_VARS", "{}") - - def __str__(self): - return f"ConfigInterface(threads={self.threads}, target={self.target}, profiles_dir={self.profiles_dir}, project_dir={self.project_dir}, profile={self.profile}, target_path={self.target_path})" - - -class ManifestProxy(UserDict): - """Proxy for manifest dictionary (`flat_graph`), if we need mutation then we should - create a copy of the dict or interface with the dbt-core manifest object instead""" - - def _readonly(self, *args, **kwargs): - raise RuntimeError("Cannot modify ManifestProxy") - - __setitem__ = _readonly - __delitem__ = _readonly - pop = _readonly - popitem = _readonly - clear = _readonly - update = _readonly - setdefault = _readonly - - -class DbtAdapterExecutionResult: - """Interface for execution results, this keeps us 1 layer removed from dbt interfaces which may change""" - - def __init__( - self, - adapter_response: "AdapterResponse", - table: agate.Table, - raw_sql: str, - compiled_sql: str, - ) -> None: - self.adapter_response = adapter_response - self.table = table - self.raw_sql = raw_sql - self.compiled_sql = compiled_sql - - -class DbtAdapterCompilationResult: - """Interface for compilation results, this keeps us 1 layer removed from dbt interfaces which may change""" - - def __init__(self, raw_sql: str, compiled_sql: str, node: "ManifestNode") -> None: - self.raw_sql = raw_sql - self.compiled_sql = compiled_sql - self.node = node - - -class DbtProject: - """Container for a dbt project. The dbt attribute is the primary interface for - dbt-core. The adapter attribute is the primary interface for the dbt adapter""" - - def __init__( - self, - target_name: Optional[str] = None, - profiles_dir: Optional[str] = None, - project_dir: Optional[str] = None, - threads: Optional[int] = 1, - profile: Optional[str] = None, - target_path: Optional[str] = None, - defer_to_prod: bool = False, - manifest_path: Optional[str] = None, - favor_state: bool = False, - vars: Optional[Dict[str, Any]] = {}, - ): - self.args = ConfigInterface( - threads=threads, - target=target_name, - profiles_dir=profiles_dir, - project_dir=project_dir, - profile=profile, - target_path=target_path, - defer=defer_to_prod, - state=manifest_path, - favor_state=favor_state, - vars=vars, - ) - - # Utilities - self._sql_parser: Optional[SqlBlockParser] = None - self._macro_parser: Optional[SqlMacroParser] = None - self._sql_runner: Optional[SqlExecuteRunner] = None - self._sql_compiler: Optional[SqlCompileRunner] = None - - # Tracks internal state version - self._version: int = 1 - self.mutex = threading.Lock() - self.defer_to_prod = defer_to_prod - self.defer_to_prod_manifest_path = manifest_path - self.favor_state = favor_state - - def init_config(self): - if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8: - from dbt_common.context import set_invocation_context - from dbt.flags import get_flags - set_invocation_context(os.environ) - set_from_args(self.args, None) - # Copy over global_flags - for key, value in get_flags().__dict__.items(): - if key not in self.args.__dict__: - self.args.__dict__[key] = value - else: - set_from_args(self.args, self.args) - self.config = RuntimeConfig.from_args(self.args) - if hasattr(self.config, "source_paths"): - self.config.model_paths = self.config.source_paths - if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8: - from dbt.mp_context import get_mp_context - 
register_adapter(self.config, get_mp_context()) - else: - register_adapter(self.config) - - def init_project(self): - try: - self.init_config() - self.adapter = get_adapter(self.config) - self.adapter.connections.set_connection_name() - if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8: - from dbt.context.providers import generate_runtime_macro_context - self.adapter.set_macro_context_generator(generate_runtime_macro_context) - self.create_parser() - except Exception as e: - # reset project - self.config = None - self.dbt = None - raise Exception(str(e)) - - def parse_project(self) -> None: - try: - self.create_parser() - self.dbt.build_flat_graph() - except Exception as e: - # reset manifest - self.dbt = None - raise Exception(str(e)) - - self._sql_parser = None - self._macro_parser = None - self._sql_compiler = None - self._sql_runner = None - - def create_parser(self) -> None: - all_projects = self.config.load_dependencies() - # filter out project with value LoomRunnableConfig class type as those projects are dependency projects - # https://github.com/AltimateAI/vscode-dbt-power-user/issues/1224 - all_projects = {k: v for k, v in all_projects.items() if not v.__class__.__name__ == "LoomRunnableConfig"} - - project_parser = ManifestLoader( - self.config, - all_projects, - self.adapter.connections.set_query_header, - ) - self.dbt = project_parser.load() - project_parser.save_macros_to_adapter(self.adapter) - - def set_defer_config( - self, defer_to_prod: bool, manifest_path: str, favor_state: bool - ) -> None: - if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8: - self.args.defer = defer_to_prod - self.args.state = manifest_path - self.args.favor_state = favor_state - self.defer_to_prod = defer_to_prod - self.defer_to_prod_manifest_path = manifest_path - self.favor_state = favor_state - - @classmethod - def from_args(cls, args: ConfigInterface) -> "DbtProject": - """Instatiate the DbtProject directly from a ConfigInterface instance""" - return cls( - target=args.target, - profiles_dir=args.profiles_dir, - project_dir=args.project_dir, - threads=args.threads, - profile=args.profile, - target_path=args.target_path, - vars=args.vars, - ) - - @property - def sql_parser(self) -> SqlBlockParser: - """A dbt-core SQL parser capable of parsing and adding nodes to the manifest via `parse_remote` which will - also return the added node to the caller. Note that post-parsing this still typically requires calls to - `_process_nodes_for_ref` and `_process_sources_for_ref` from `dbt.parser.manifest` - """ - if self._sql_parser is None: - self._sql_parser = SqlBlockParser(self.config, self.dbt, self.config) - return self._sql_parser - - @property - def macro_parser(self) -> SqlMacroParser: - """A dbt-core macro parser""" - if self._macro_parser is None: - self._macro_parser = SqlMacroParser(self.config, self.dbt) - return self._macro_parser - - @property - def sql_runner(self) -> SqlExecuteRunner: - """A runner which is used internally by the `execute_sql` function of `dbt.lib`. - The runners `node` attribute can be updated before calling `compile` or `compile_and_execute`. - """ - if self._sql_runner is None: - self._sql_runner = SqlExecuteRunner( - self.config, self.adapter, node=None, node_index=1, num_nodes=1 - ) - return self._sql_runner - - @property - def sql_compiler(self) -> SqlCompileRunner: - """A runner which is used internally by the `compile_sql` function of `dbt.lib`. - The runners `node` attribute can be updated before calling `compile` or `compile_and_execute`. 
- """ - if self._sql_compiler is None: - self._sql_compiler = SqlCompileRunner( - self.config, self.adapter, node=None, node_index=1, num_nodes=1 - ) - return self._sql_compiler - - @property - def project_name(self) -> str: - """dbt project name""" - return self.config.project_name - - @property - def project_root(self) -> str: - """dbt project root""" - return self.config.project_root - - @property - def manifest(self) -> ManifestProxy: - """dbt manifest dict""" - return ManifestProxy(self.dbt.flat_graph) - - def safe_parse_project(self) -> None: - self.clear_caches() - # reinit the project because config may change - # this operation is cheap anyway - self.init_project() - # doing this so that we can allow inits to fail when config is - # bad and restart after the user sets it up correctly - if hasattr(self, "config"): - _config_pointer = copy(self.config) - else: - _config_pointer = None - try: - self.parse_project() - self.write_manifest_artifact() - - if self.defer_to_prod: - if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8: - writable_manifest = WritableManifest.read_and_check_versions(self.defer_to_prod_manifest_path) - manifest = Manifest.from_writable_manifest(writable_manifest) - self.dbt.merge_from_artifact( - other=manifest, - ) - else: - with open(self.defer_to_prod_manifest_path) as f: - manifest = WritableManifest.from_dict(json.load(f)) - selected = set() - self.dbt.merge_from_artifact( - self.adapter, - other=manifest, - selected=selected, - favor_state=self.favor_state, - ) - except Exception as e: - self.config = _config_pointer - raise Exception(str(e)) - - def write_manifest_artifact(self) -> None: - """Write a manifest.json to disk""" - artifact_path = os.path.join( - self.config.project_root, self.config.target_path, MANIFEST_ARTIFACT - ) - self.dbt.write(artifact_path) - - def clear_caches(self) -> None: - """Clear least recently used caches and reinstantiable container objects""" - self.get_ref_node.cache_clear() - self.get_source_node.cache_clear() - self.get_macro_function.cache_clear() - self.get_columns.cache_clear() - - @lru_cache(maxsize=10) - def get_ref_node(self, target_model_name: str) -> "ManifestNode": - """Get a `"ManifestNode"` from a dbt project model name""" - try: - if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 6: - return self.dbt.resolve_ref( - source_node=None, - target_model_name=target_model_name, - target_model_version=None, - target_model_package=None, - current_project=self.config.project_name, - node_package=self.config.project_name, - ) - if DBT_MAJOR_VER == 1 and DBT_MINOR_VER >= 5: - return self.dbt.resolve_ref( - target_model_name=target_model_name, - target_model_version=None, - target_model_package=None, - current_project=self.config.project_name, - node_package=self.config.project_name, - ) - return self.dbt.resolve_ref( - target_model_name=target_model_name, - target_model_package=None, - current_project=self.config.project_name, - node_package=self.config.project_name, - ) - except Exception as e: - raise Exception(str(e)) - - @lru_cache(maxsize=10) - def get_source_node( - self, target_source_name: str, target_table_name: str - ) -> "ManifestNode": - """Get a `"ManifestNode"` from a dbt project source name and table name""" - try: - return self.dbt.resolve_source( - target_source_name=target_source_name, - target_table_name=target_table_name, - current_project=self.config.project_name, - node_package=self.config.project_name, - ) - except Exception as e: - raise Exception(str(e)) - - def get_server_node(self, sql: str, 
node_name="name", original_node: Optional[Union["ManifestNode", str]] = None): - """Get a node for SQL execution against adapter""" - self._clear_node(node_name) - sql_node = self.sql_parser.parse_remote(sql, node_name) - # Enable copying original node properties - if original_node is not None: - if isinstance(original_node, str): - original_node = self.get_ref_node(original_node) - if original_node is not None and isinstance(original_node.node_info, dict) and "materialized" in original_node.node_info.keys() and original_node.node_info["materialized"] == "incremental": - sql_node.schema = original_node.schema - sql_node.database = original_node.database - sql_node.alias = original_node.alias - sql_node.node_info["materialized"] = "incremental" - sql_node.node_info.update({k: v for k, v in original_node.node_info.items() if k not in sql_node.node_info.keys()}) - process_node(self.config, self.dbt, sql_node) - return sql_node - - @lru_cache(maxsize=100) - def get_macro_function(self, macro_name: str, compiled_code: Optional[str] = None) -> Callable[[Dict[str, Any]], Any]: - """Get macro as a function which takes a dict via argument named `kwargs`, - ie: `kwargs={"relation": ...}` - - make_schema_fn = get_macro_function('make_schema')\n - make_schema_fn({'name': '__test_schema_1'})\n - make_schema_fn({'name': '__test_schema_2'})""" - if DBT_MAJOR_VER >= 1 and DBT_MINOR_VER >= 8: - model_context = {} - if compiled_code is not None: - model_context["compiled_code"] = compiled_code - return partial( - self.adapter.execute_macro, macro_name=macro_name, context_override=model_context, - ) - else: - return partial( - self.adapter.execute_macro, macro_name=macro_name, manifest=self.dbt - ) - - def adapter_execute( - self, sql: str, auto_begin: bool = True, fetch: bool = False - ) -> Tuple["AdapterResponse", agate.Table]: - """Wraps adapter.execute. Execute SQL against database""" - return self.adapter.execute(sql, auto_begin, fetch) - - def execute_macro( - self, - macro: str, - kwargs: Optional[Dict[str, Any]] = None, - compiled_code: Optional[str] = None - ) -> Any: - """Wraps adapter execute_macro. 
Execute a macro like a function.""" - return self.get_macro_function(macro, compiled_code)(kwargs=kwargs) - - def execute_sql(self, raw_sql: str, original_node: Optional[Union["ManifestNode", str]] = None) -> DbtAdapterExecutionResult: - """Execute dbt SQL statement against database""" - with self.adapter.connection_named("master"): - # if no jinja chars then these are synonymous - compiled_sql = raw_sql - if has_jinja(raw_sql): - # jinja found, compile it - compilation_result = self._compile_sql(raw_sql, original_node) - compiled_sql = compilation_result.compiled_sql - - return DbtAdapterExecutionResult( - *self.adapter_execute(compiled_sql, fetch=True), - raw_sql, - compiled_sql, - ) - - def execute_node(self, node: "ManifestNode") -> DbtAdapterExecutionResult: - """Execute dbt SQL statement against database from a"ManifestNode""" - try: - if node is None: - raise ValueError("This model doesn't exist within this dbt project") - raw_sql: str = getattr(node, RAW_CODE) - compiled_sql: Optional[str] = getattr(node, COMPILED_CODE, None) - if compiled_sql: - # node is compiled, execute the SQL - return self.execute_sql(compiled_sql) - # node not compiled - if has_jinja(raw_sql): - # node has jinja in its SQL, compile it - compiled_sql = self._compile_node(node).compiled_sql - # execute the SQL - return self.execute_sql(compiled_sql or raw_sql) - except Exception as e: - raise Exception(str(e)) - - def compile_sql(self, raw_sql: str, original_node: Optional["ManifestNode"] = None) -> DbtAdapterCompilationResult: - try: - with self.adapter.connection_named("master"): - return self._compile_sql(raw_sql, original_node) - except Exception as e: - raise Exception(str(e)) - - def compile_node( - self, node: "ManifestNode" - ) -> Optional[DbtAdapterCompilationResult]: - try: - if node is None: - raise ValueError("This model doesn't exist within this dbt project") - with self.adapter.connection_named("master"): - return self._compile_node(node) - except Exception as e: - raise Exception(str(e)) - - def _compile_sql(self, raw_sql: str, original_node: Optional[Union["ManifestNode", str]] = None) -> DbtAdapterCompilationResult: - """Creates a node with a `dbt.parser.sql` class. 
Compile generated node.""" - try: - temp_node_id = str("t_" + uuid.uuid4().hex) - server_node = self.get_server_node(raw_sql, temp_node_id, original_node) - node = self._compile_node(server_node) - self._clear_node(temp_node_id) - return node - except Exception as e: - raise Exception(str(e)) - - def _compile_node( - self, node: Union["ManifestNode", "CompiledNode"] - ) -> Optional[DbtAdapterCompilationResult]: - """Compiles existing node.""" - try: - self.sql_compiler.node = copy(node) - if DBT_MAJOR_VER == 1 and DBT_MINOR_VER <= 3: - compiled_node = ( - node - if isinstance(node, CompiledNode) - else self.sql_compiler.compile(self.dbt) - ) - else: - # this is essentially a convenient wrapper to adapter.get_compiler - compiled_node = self.sql_compiler.compile(self.dbt) - return DbtAdapterCompilationResult( - getattr(compiled_node, RAW_CODE), - getattr(compiled_node, COMPILED_CODE), - compiled_node, - ) - except Exception as e: - raise Exception(str(e)) - - def _clear_node(self, name="name"): - """Removes the statically named node created by `execute_sql` and `compile_sql` in `dbt.lib`""" - if self.dbt is not None: - self.dbt.nodes.pop( - f"{NodeType.SqlOperation}.{self.project_name}.{name}", None - ) - - def get_relation( - self, database: Optional[str], schema: Optional[str], name: Optional[str] - ) -> Optional["BaseRelation"]: - """Wrapper for `adapter.get_relation`""" - return self.adapter.get_relation(database, schema, name) - - def create_relation( - self, database: Optional[str], schema: Optional[str], name: Optional[str] - ) -> "BaseRelation": - """Wrapper for `adapter.Relation.create`""" - return self.adapter.Relation.create(database, schema, name) - - def create_relation_from_node(self, node: "ManifestNode") -> "BaseRelation": - """Wrapper for `adapter.Relation.create_from`""" - return self.adapter.Relation.create_from(self.config, node) - - def get_columns_in_relation(self, relation: "BaseRelation") -> List[str]: - """Wrapper for `adapter.get_columns_in_relation`""" - try: - with self.adapter.connection_named("master"): - return self.adapter.get_columns_in_relation(relation) - except Exception as e: - raise Exception(str(e)) - - @lru_cache(maxsize=5) - def get_columns(self, node: "ManifestNode") -> List["ColumnInfo"]: - """Get a list of columns from a compiled node""" - columns = [] - try: - columns.extend( - [ - c.name - for c in self.get_columns_in_relation( - self.create_relation_from_node(node) - ) - ] - ) - except Exception: - original_sql = str(getattr(node, RAW_CODE)) - # TODO: account for `TOP` syntax - setattr(node, RAW_CODE, f"select * from ({original_sql}) limit 0") - result = self.execute_node(node) - setattr(node, RAW_CODE, original_sql) - delattr(node, COMPILED_CODE) - columns.extend(result.table.column_names) - return columns - - def get_catalog(self) -> Dict[str, Any]: - """Get catalog from adapter""" - catalog_table: agate.Table = agate.Table([]) - catalog_data: List[PrimitiveDict] = [] - exceptions: List[Exception] = [] - try: - with self.adapter.connection_named("generate_catalog"): - catalog_table, exceptions = self.adapter.get_catalog(self.dbt) - - if exceptions: - raise Exception(str(exceptions)) - - catalog_data = [ - dict( - zip(catalog_table.column_names, map(dbt.utils._coerce_decimal, row)) - ) - for row in catalog_table - ] - - except Exception as e: - raise Exception(str(e)) - return catalog_data - - def get_or_create_relation( - self, database: str, schema: str, name: str - ) -> Tuple["BaseRelation", bool]: - """Get relation or create if not exists. 
Returns tuple of relation and - boolean result of whether it existed ie: (relation, did_exist)""" - ref = self.get_relation(database, schema, name) - return ( - (ref, True) - if ref - else (self.create_relation(database, schema, name), False) - ) - - def create_schema(self, node: "ManifestNode"): - """Create a schema in the database""" - return self.execute_macro( - "create_schema", - kwargs={"relation": self.create_relation_from_node(node)}, - ) - - def materialize( - self, node: "ManifestNode", temporary: bool = True - ) -> Tuple["AdapterResponse", None]: - """Materialize a table in the database""" - return self.adapter_execute( - # Returns CTAS string so send to adapter.execute - self.execute_macro( - "create_table_as", - kwargs={ - "sql": getattr(node, COMPILED_CODE), - "relation": self.create_relation_from_node(node), - "temporary": temporary, - }, - ), - auto_begin=True, - ) - - def get_dbt_version(self): - return [DBT_MAJOR_VER, DBT_MINOR_VER, DBT_PATCH_VER] - - def validate_sql_dry_run(self, compiled_sql: str): - if DBT_MAJOR_VER < 1: - return None - if DBT_MINOR_VER < 6: - return None - try: - return self.adapter.validate_sql(compiled_sql) - except Exception as e: - raise Exception(str(e)) - - def get_target_names(self): - from dbt.config.profile import read_profile - profile = read_profile(self.args.profiles_dir) - profile = profile[self.config.profile_name] - if "outputs" in profile: - outputs = profile["outputs"] - return outputs.keys() - return [] - - def set_selected_target(self, target: str): - self.args.target = target - - def cleanup_connections(self): - try: - self.adapter.cleanup_connections() - except Exception as e: - raise Exception(str(e)) diff --git a/dbt_healthcheck.py b/dbt_healthcheck.py deleted file mode 100644 index 0580f6c2a..000000000 --- a/dbt_healthcheck.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Optional - -def project_healthcheck( - manifest_path, catalog_path=None, config_path=None, config=None, token=None, tenant=None, backend_url: Optional[str] = None, -): - try: - import logging - import json - - from datapilot.config.config import load_config - from datapilot.core.platforms.dbt.utils import load_catalog - from datapilot.core.platforms.dbt.utils import load_manifest - from datapilot.core.platforms.dbt.constants import MODEL - from datapilot.core.platforms.dbt.executor import DBTInsightGenerator - - logging.basicConfig(level=logging.INFO) - manifest = load_manifest(manifest_path) - catalog = load_catalog(catalog_path) if catalog_path else None - if not config and config_path: - config = load_config(config_path) - insight_generator = DBTInsightGenerator( - manifest=manifest, - catalog=catalog, - config=config, - token=token, - instance_name=tenant, - backend_url=backend_url, - ) - reports = insight_generator.run() - - # package_insights = reports[PROJECT] - model_insights = { - k: [json.loads(item.json()) for item in v] - for k, v in reports[MODEL].items() - } - - return {"model_insights": model_insights} - - except Exception as e: - raise Exception(str(e)) diff --git a/jest.config.js b/jest.config.js index 6b9713756..dc92011d2 100644 --- a/jest.config.js +++ b/jest.config.js @@ -17,5 +17,11 @@ module.exports = { moduleNameMapper: { "^vscode$": "/src/test/mock/vscode.ts", "^@lib$": "/src/test/mock/lib.ts", + "^node-fetch$": "/src/test/mock/node-fetch.ts", + // Development: use local TypeScript source (same as webpack and tsconfig) + // "^@altimateai/dbt-integration$": + // "/../altimate-dbt-integration/src/index.ts", + // Production: use npm 
package (commented out for development) + "^@altimateai/dbt-integration$": "@altimateai/dbt-integration", }, }; diff --git a/package-lock.json b/package-lock.json index 1a63c5ba3..f2489508c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,6 +10,7 @@ "hasInstallScript": true, "license": "MIT", "dependencies": { + "@altimateai/dbt-integration": "^0.0.6", "@jupyterlab/coreutils": "^6.2.4", "@jupyterlab/nbformat": "^4.2.4", "@jupyterlab/services": "^6.6.7", @@ -88,6 +89,22 @@ "vscode": "^1.95.0" } }, + "node_modules/@altimateai/dbt-integration": { + "version": "0.0.6", + "resolved": "https://registry.npmjs.org/@altimateai/dbt-integration/-/dbt-integration-0.0.6.tgz", + "integrity": "sha512-Ng1iF7c03hEs0RRtx73AhH/2kru8ZjpNGyx+UYYjLaaDcFC2loV4nY70Y2rH8DUWTgM8jvExSMau2aHAjw0kbQ==", + "license": "MIT", + "dependencies": { + "node-abort-controller": "^3.1.1", + "node-fetch": "^3.3.2", + "python-bridge": "git+https://github.com/mdesmet/node-python-bridge.git#feat/detached", + "semver": "^7.6.3", + "yaml": "^2.5.0" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/@ampproject/remapping": { "version": "2.3.0", "resolved": "https://registry.npmjs.org/@ampproject/remapping/-/remapping-2.3.0.tgz", diff --git a/package.json b/package.json index 51b70ddd5..351a73d61 100644 --- a/package.json +++ b/package.json @@ -1392,6 +1392,7 @@ "altimateai.vscode-altimate-mcp-server" ], "dependencies": { + "@altimateai/dbt-integration": "^0.0.6", "@jupyterlab/coreutils": "^6.2.4", "@jupyterlab/nbformat": "^4.2.4", "@jupyterlab/services": "^6.6.7", diff --git a/src/altimate.ts b/src/altimate.ts index c2ac7c901..799708083 100644 --- a/src/altimate.ts +++ b/src/altimate.ts @@ -1,32 +1,18 @@ +import { + AltimateHttpClient, + ColumnMetaData, + DBTConfiguration, + DBTTerminal, + NodeMetaData, + SourceMetaData, +} from "@altimateai/dbt-integration"; +import { NotebookItem, NotebookSchema, PreconfiguredNotebookItem } from "@lib"; +import { inject } from "inversify"; import type { RequestInit } from "node-fetch"; -import { CommentThread, env, Uri, window, workspace } from "vscode"; -import { provideSingleton, processStreamResponse } from "./utils"; -import { ColumnMetaData, NodeMetaData, SourceMetaData } from "./domain"; -import { TelemetryService } from "./telemetry"; -import { join } from "path"; -import { createReadStream, createWriteStream, mkdirSync, ReadStream } from "fs"; -import * as os from "os"; -import { RateLimitException, ExecutionsExhaustedException } from "./exceptions"; -import { DBTProject } from "./manifest/dbtProject"; -import { DBTTerminal } from "./dbt_client/dbtTerminal"; -import { PythonEnvironment } from "./manifest/pythonEnvironment"; -import { PreconfiguredNotebookItem, NotebookItem, NotebookSchema } from "@lib"; import * as vscode from "vscode"; -export class NoCredentialsError extends Error {} - -export class NotFoundError extends Error {} - export class UserInputError extends Error {} -export class ForbiddenError extends Error { - constructor() { - super("Invalid credentials. 
Please check instance name and API Key."); - } -} - -export class APIError extends Error {} - export interface ColumnLineage { source: { uniqueId: string; column_name: string }; target: { uniqueId: string; column_name: string }; @@ -252,25 +238,6 @@ export interface DBTCoreIntegration { last_file_upload_time: string; } -interface DownloadArtifactResponse { - url: string; - dbt_core_integration_file_id: number; -} - -export type ValidateSqlParseErrorType = - | "sql_parse_error" - | "sql_invalid_error" - | "sql_unknown_error"; - -export interface ValidateSqlParseErrorResponse { - error_type?: ValidateSqlParseErrorType; - errors: { - description: string; - start_position?: [number, number]; - end_position?: [number, number]; - }[]; -} - export interface TenantUser { id: number; uuid: string; @@ -296,15 +263,6 @@ interface BulkDocsPropRequest { session_id: string; } -interface AltimateConfig { - key: string; - instance: string; -} - -enum PromptAnswer { - YES = "Get your free API Key", -} - export interface SharedDoc { share_id: number; name: string; @@ -341,81 +299,32 @@ export interface ConversationGroup { conversations: Conversation[]; } -@provideSingleton(AltimateRequest) export class AltimateRequest { - public static ALTIMATE_URL = workspace - .getConfiguration("dbt") - .get("altimateUrl", "https://api.myaltimate.com"); - constructor( - private telemetry: TelemetryService, private dbtTerminal: DBTTerminal, - private pythonEnvironment: PythonEnvironment, + @inject("DBTConfiguration") + private dbtConfiguration: DBTConfiguration, + private altimateHttpClient: AltimateHttpClient, ) {} - private async internalFetch(url: string, init?: RequestInit) { - const nodeFetch = (await import("node-fetch")).default; - return nodeFetch(url, init); + public getAltimateUrl(): string { + return this.altimateHttpClient.getAltimateUrl(); + } + + private async internalFetch(url: string, init?: RequestInit) { + return this.altimateHttpClient.internalFetch(url, init); } getInstanceName() { - return this.pythonEnvironment.getResolvedConfigValue( - "altimateInstanceName", - ); + return this.dbtConfiguration.getAltimateInstanceName(); } getAIKey() { - return this.pythonEnvironment.getResolvedConfigValue("altimateAiKey"); + return this.dbtConfiguration.getAltimateAiKey(); } public enabled(): boolean { - return !!this.getConfig(); - } - - private async showAPIKeyMessage(message: string) { - const answer = await window.showInformationMessage( - message, - PromptAnswer.YES, - ); - if (answer === PromptAnswer.YES) { - env.openExternal( - Uri.parse("https://app.myaltimate.com/register?source=extension"), - ); - } - } - - private getConfig(): AltimateConfig | undefined { - const key = this.getAIKey(); - const instance = this.getInstanceName(); - if (!key || !instance) { - return undefined; - } - return { key, instance }; - } - - getCredentialsMessage(): string | undefined { - const key = this.getAIKey(); - const instance = this.getInstanceName(); - - if (!key && !instance) { - return `To use this feature, please add an API Key and an instance name in the settings.`; - } - if (!key) { - return `To use this feature, please add an API key in the settings.`; - } - if (!instance) { - return `To use this feature, please add an instance name in the settings.`; - } - return; - } - - handlePreviewFeatures(): boolean { - const message = this.getCredentialsMessage(); - if (!message) { - return true; - } - this.showAPIKeyMessage(message); - return false; + return !!this.altimateHttpClient.getConfig(); } async fetchAsStream( @@ 
-424,103 +333,12 @@ export class AltimateRequest { onProgress: (response: string) => void, timeout: number = 120000, ) { - this.throwIfLocalMode(endpoint); - const url = `${AltimateRequest.ALTIMATE_URL}/${endpoint}`; - this.dbtTerminal.debug("fetchAsStream:request", url, request); - const config = this.getConfig()!; - const abortController = new AbortController(); - const timeoutHandler = setTimeout(() => { - abortController.abort(); - }, timeout); - try { - const response = await this.internalFetch(url, { - method: "POST", - body: JSON.stringify(request), - signal: abortController.signal, - headers: { - "x-tenant": config.instance, - Authorization: "Bearer " + config.key, - "Content-Type": "application/json", - }, - }); - - if (response.ok && response.status === 200) { - if (!response?.body) { - this.dbtTerminal.debug("fetchAsStream", "empty response"); - return null; - } - const responseText = await processStreamResponse( - response.body, - onProgress, - ); - - return responseText; - } - if ( - // response codes when backend authorization fails - response.status === 401 || - response.status === 403 - ) { - this.telemetry.sendTelemetryEvent("invalidCredentials", { url }); - throw new ForbiddenError(); - } - if (response.status === 404) { - this.telemetry.sendTelemetryEvent("resourceNotFound", { url }); - throw new NotFoundError("Resource Not found"); - } - if (response.status === 402) { - const jsonResponse = (await response.json()) as { detail: string }; - throw new ExecutionsExhaustedException(jsonResponse.detail); - } - const textResponse = await response.text(); - this.dbtTerminal.debug( - "network:response", - "error from backend", - textResponse, - ); - if (response.status === 429) { - throw new RateLimitException( - textResponse, - response.headers.get("Retry-After") - ? 
parseInt(response.headers.get("Retry-After") || "") - : 1 * 60 * 1000, // default to 1 min - ); - } - this.telemetry.sendTelemetryError("apiError", { - endpoint, - status: response.status, - textResponse, - }); - throw new APIError( - `Could not process request, server responded with ${response.status}: ${textResponse}`, - ); - } catch (error) { - this.dbtTerminal.error( - "apiCatchAllError", - "fetchAsStream catchAllError", - error, - true, - { - endpoint, - }, - ); - throw error; - } finally { - clearTimeout(timeoutHandler); - } - return null; - } - - private async readStreamToBlob(stream: ReadStream) { - return new Promise((resolve, reject) => { - const chunks: any[] = []; - stream.on("data", (chunk) => chunks.push(chunk)); - stream.on("end", () => { - const blob = new Blob(chunks); - resolve(blob); - }); - stream.on("error", reject); - }); + return this.altimateHttpClient.fetchAsStream( + endpoint, + request, + onProgress, + timeout, + ); } async uploadToS3( @@ -528,65 +346,38 @@ export class AltimateRequest { fetchArgs: Record, filePath: string, ) { - this.dbtTerminal.debug("uploadToS3:", endpoint, fetchArgs, filePath); - this.throwIfLocalMode(endpoint); - - const blob = (await this.readStreamToBlob( - createReadStream(filePath), - )) as Blob; - const response = await this.internalFetch(endpoint, { - ...fetchArgs, - method: "PUT", - body: blob, - }); - - this.dbtTerminal.debug( - "uploadToS3:response:", - `${response.status}`, - response.statusText, - ); - if (!response.ok || response.status !== 200) { - const textResponse = await response.text(); - this.telemetry.sendTelemetryError("uploadToS3", { - endpoint, - status: response.status, - textResponse, - }); - throw new Error("Failed to upload data to signed url"); - } - - return response; + return this.altimateHttpClient.uploadToS3(endpoint, fetchArgs, filePath); } private throwIfLocalMode(endpoint: string) { - const isLocalMode = workspace - .getConfiguration("dbt") - .get("isLocalMode", false); - if (!isLocalMode) { - return; - } - endpoint = endpoint.split("?")[0]; - if ( - [/^dbtconfig\/datapilot_version\/.*$/, /^dbtconfig\/.*\/download$/].some( - (regex) => regex.test(endpoint), - ) - ) { - return; - } - if ( - [ - "auth_health", - "dbtconfig", - "dbt/v1/fetch_artifact_url", - "dbtconfig/extension/start_scan", - "dbt/v1/project_integrations", - "dbt/v1/defer_to_prod_event", - "dbt/v3/validate-credentials", - ].includes(endpoint) - ) { - return; + try { + this.altimateHttpClient.throwIfLocalMode(endpoint); + } catch (error) { + // Apply additional local mode exceptions specific to AltimateRequest + endpoint = endpoint.split("?")[0]; + if ( + [ + /^dbtconfig\/datapilot_version\/.*$/, + /^dbtconfig\/.*\/download$/, + ].some((regex) => regex.test(endpoint)) + ) { + return; + } + if ( + [ + "auth_health", + "dbtconfig", + "dbt/v1/fetch_artifact_url", + "dbtconfig/extension/start_scan", + "dbt/v1/project_integrations", + "dbt/v1/defer_to_prod_event", + "dbt/v3/validate-credentials", + ].includes(endpoint) + ) { + return; + } + throw error; } - throw new Error("This feature is not supported in local mode."); } async fetch( @@ -596,133 +387,11 @@ export class AltimateRequest { ): Promise { this.dbtTerminal.debug("network:request", endpoint, fetchArgs); this.throwIfLocalMode(endpoint); - - const abortController = new AbortController(); - const timeoutHandler = setTimeout(() => { - abortController.abort(); - }, timeout); - - const message = this.getCredentialsMessage(); - if (message) { - throw new NoCredentialsError(message); - } - const 
config = this.getConfig()!; - - try { - const url = `${AltimateRequest.ALTIMATE_URL}/${endpoint}`; - const response = await this.internalFetch(url, { - method: "GET", - ...fetchArgs, - signal: abortController.signal, - headers: { - "x-tenant": config.instance, - Authorization: "Bearer " + config.key, - "Content-Type": "application/json", - }, - }); - this.dbtTerminal.debug("network:response", endpoint, response.status); - if (response.ok && response.status === 200) { - const jsonResponse = await response.json(); - return jsonResponse as T; - } - if ( - // response codes when backend authorization fails - response.status === 401 || - response.status === 403 - ) { - this.telemetry.sendTelemetryEvent("invalidCredentials", { url }); - throw new ForbiddenError(); - } - if (response.status === 404) { - this.telemetry.sendTelemetryEvent("resourceNotFound", { url }); - throw new NotFoundError("Resource Not found"); - } - if (response.status === 402) { - const jsonResponse = (await response.json()) as { detail: string }; - throw new ExecutionsExhaustedException(jsonResponse.detail); - } - const textResponse = await response.text(); - this.dbtTerminal.debug( - "network:response", - "error from backend", - textResponse, - ); - if (response.status === 429) { - throw new RateLimitException( - textResponse, - response.headers.get("Retry-After") - ? parseInt(response.headers.get("Retry-After") || "") - : 1 * 60 * 1000, // default to 1 min - ); - } - this.telemetry.sendTelemetryError("apiError", { - endpoint, - status: response.status, - textResponse, - }); - let jsonResponse: any; - try { - jsonResponse = JSON.parse(textResponse); - } catch {} - throw new APIError( - `Could not process request, server responded with ${response.status}: ${jsonResponse?.detail || textResponse}`, - ); - } catch (e) { - this.dbtTerminal.error("apiCatchAllError", "catchAllError", e, true, { - endpoint, - }); - throw e; - } finally { - clearTimeout(timeoutHandler); - } - } - - async downloadFileLocally( - artifactUrl: string, - projectRoot: Uri, - fileName = "manifest.json", - ): Promise { - const hashedProjectRoot = DBTProject.hashProjectRoot(projectRoot.fsPath); - const tempFolder = join(os.tmpdir(), hashedProjectRoot); - - try { - this.dbtTerminal.debug( - "AltimateRequest", - `creating temporary folder: ${tempFolder} for file: ${fileName}`, - ); - mkdirSync(tempFolder, { recursive: true }); - - const destinationPath = join(tempFolder, fileName); - - this.dbtTerminal.debug( - "AltimateRequest", - `fetching artifactUrl: ${artifactUrl}`, - ); - const response = await this.internalFetch(artifactUrl, { - agent: undefined, - }); - - if (!response.ok) { - throw new Error(`Failed to download file: ${response.statusText}`); - } - const fileStream = createWriteStream(destinationPath); - await new Promise((resolve, reject) => { - response.body?.pipe(fileStream); - response.body?.on("error", reject); - fileStream.on("finish", resolve); - }); - - this.dbtTerminal.debug("File downloaded successfully!", fileName); - return tempFolder; - } catch (err) { - this.dbtTerminal.error( - "downloadFileLocally", - `Could not save ${fileName} locally`, - err, - ); - window.showErrorMessage(`Could not save ${fileName} locally: ${err}`); - throw err; - } + return this.altimateHttpClient.fetch( + endpoint, + fetchArgs as RequestInit, + timeout, + ); } private getQueryString = ( @@ -778,7 +447,7 @@ export class AltimateRequest { } async validateCredentials(instance: string, key: string) { - const url = 
`${AltimateRequest.ALTIMATE_URL}/dbt/v3/validate-credentials`; + const url = `${this.getAltimateUrl()}/dbt/v3/validate-credentials`; const response = await fetch(url, { method: "GET", headers: { @@ -791,7 +460,7 @@ export class AltimateRequest { } async checkApiConnectivity() { - const url = `${AltimateRequest.ALTIMATE_URL}/health`; + const url = `${this.getAltimateUrl()}/health`; try { const response = await this.internalFetch(url, { method: "GET" }); const { status } = (await response.json()) as { status: string }; @@ -813,22 +482,6 @@ export class AltimateRequest { return this.fetch("dbt/v1/project_integrations"); } - async sendDeferToProdEvent(defer_type: string) { - return this.fetch("dbt/v1/defer_to_prod_event", { - method: "POST", - body: JSON.stringify({ defer_type }), - }); - } - - async fetchArtifactUrl(artifact_type: string, dbtCoreIntegrationId: number) { - return this.fetch( - `dbt/v1/fetch_artifact_url${this.getQueryString({ - artifact_type: artifact_type, - dbt_core_integration_id: dbtCoreIntegrationId, - })}`, - ); - } - async getHealthcheckConfigs() { return this.fetch( `dbtconfig${this.getQueryString({ size: "100" })}`, @@ -923,8 +576,6 @@ export class AltimateRequest { data: { name: string; description?: string; - uri?: Uri; - model?: string; }, projectName: string, ) { diff --git a/src/autocompletion_provider/docAutocompletionProvider.ts b/src/autocompletion_provider/docAutocompletionProvider.ts index 31ddc6bb0..fb2493243 100755 --- a/src/autocompletion_provider/docAutocompletionProvider.ts +++ b/src/autocompletion_provider/docAutocompletionProvider.ts @@ -11,12 +11,11 @@ import { TextDocument, Uri, } from "vscode"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { isEnclosedWithinCodeBlock, provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; +import { isEnclosedWithinCodeBlock } from "../utils"; -@provideSingleton(DocAutocompletionProvider) export class DocAutocompletionProvider implements CompletionItemProvider, Disposable { diff --git a/src/autocompletion_provider/index.ts b/src/autocompletion_provider/index.ts index d3bb3f5f4..4aab24180 100755 --- a/src/autocompletion_provider/index.ts +++ b/src/autocompletion_provider/index.ts @@ -1,13 +1,11 @@ -import { Disposable, DocumentFilter, languages } from "vscode"; +import { Disposable, languages } from "vscode"; import { DBTPowerUserExtension } from "../dbtPowerUserExtension"; -import { provideSingleton } from "../utils"; import { DocAutocompletionProvider } from "./docAutocompletionProvider"; import { MacroAutocompletionProvider } from "./macroAutocompletionProvider"; import { ModelAutocompletionProvider } from "./modelAutocompletionProvider"; import { SourceAutocompletionProvider } from "./sourceAutocompletionProvider"; import { UserCompletionProvider } from "./usercompletion_provider"; -@provideSingleton(AutocompletionProviders) export class AutocompletionProviders implements Disposable { private disposables: Disposable[] = []; diff --git a/src/autocompletion_provider/macroAutocompletionProvider.ts b/src/autocompletion_provider/macroAutocompletionProvider.ts index e8d9b51fb..95e71d7d7 100755 --- a/src/autocompletion_provider/macroAutocompletionProvider.ts +++ 
b/src/autocompletion_provider/macroAutocompletionProvider.ts @@ -11,12 +11,12 @@ import { TextDocument, Uri, } from "vscode"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { isEnclosedWithinCodeBlock, provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; +import { isEnclosedWithinCodeBlock } from "../utils"; -@provideSingleton(MacroAutocompletionProvider) // TODO autocomplete doesn't work when mistype, delete and retype +// TODO autocomplete doesn't work when mistype, delete and retype export class MacroAutocompletionProvider implements CompletionItemProvider, Disposable { diff --git a/src/autocompletion_provider/modelAutocompletionProvider.ts b/src/autocompletion_provider/modelAutocompletionProvider.ts index 1cc2a880d..7dd5273d5 100755 --- a/src/autocompletion_provider/modelAutocompletionProvider.ts +++ b/src/autocompletion_provider/modelAutocompletionProvider.ts @@ -1,3 +1,4 @@ +import { RESOURCE_TYPE_ANALYSIS } from "@altimateai/dbt-integration"; import { CancellationToken, CompletionContext, @@ -11,13 +12,12 @@ import { TextDocument, Uri, } from "vscode"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { isEnclosedWithinCodeBlock, provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; -import { DBTProject } from "../manifest/dbtProject"; +import { isEnclosedWithinCodeBlock } from "../utils"; -@provideSingleton(ModelAutocompletionProvider) // TODO autocomplete doesn't work when mistype, delete and retype +// TODO autocomplete doesn't work when mistype, delete and retype export class ModelAutocompletionProvider implements CompletionItemProvider, Disposable { @@ -141,9 +141,7 @@ export class ModelAutocompletionProvider const projectName = project.getProjectName(); const models = added.nodeMetaMap.nodes(); const autocompleteItems = Array.from(models) - .filter( - (model) => model.resource_type !== DBTProject.RESOURCE_TYPE_ANALYSIS, - ) + .filter((model) => model.resource_type !== RESOURCE_TYPE_ANALYSIS) .map((model) => ({ projectName, packageName: model.package_name, diff --git a/src/autocompletion_provider/sourceAutocompletionProvider.ts b/src/autocompletion_provider/sourceAutocompletionProvider.ts index e479f27c9..ee03847d3 100755 --- a/src/autocompletion_provider/sourceAutocompletionProvider.ts +++ b/src/autocompletion_provider/sourceAutocompletionProvider.ts @@ -11,12 +11,12 @@ import { TextDocument, Uri, } from "vscode"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { isEnclosedWithinCodeBlock, provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; +import { isEnclosedWithinCodeBlock } from "../utils"; 
-@provideSingleton(SourceAutocompletionProvider) // TODO autocomplete doesn't work when mistype, delete and retype +// TODO autocomplete doesn't work when mistype, delete and retype export class SourceAutocompletionProvider implements CompletionItemProvider, Disposable { diff --git a/src/autocompletion_provider/usercompletion_provider.ts b/src/autocompletion_provider/usercompletion_provider.ts index 54de1eb9e..672e0f326 100644 --- a/src/autocompletion_provider/usercompletion_provider.ts +++ b/src/autocompletion_provider/usercompletion_provider.ts @@ -6,10 +6,8 @@ import { Disposable, ProviderResult, } from "vscode"; -import { provideSingleton } from "../utils"; import { UsersService } from "../services/usersService"; -@provideSingleton(UserCompletionProvider) export class UserCompletionProvider implements CompletionItemProvider, Disposable { diff --git a/src/code_lens_provider/cteCodeLensProvider.ts b/src/code_lens_provider/cteCodeLensProvider.ts index 12714a518..f26ba4d58 100644 --- a/src/code_lens_provider/cteCodeLensProvider.ts +++ b/src/code_lens_provider/cteCodeLensProvider.ts @@ -1,14 +1,14 @@ +import { DBTTerminal } from "@altimateai/dbt-integration"; +import { inject } from "inversify"; import { CancellationToken, CodeLens, CodeLensProvider, Command, + Disposable, Range, TextDocument, - Disposable, } from "vscode"; -import { provideSingleton } from "../utils"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; import { AltimateRequest } from "../altimate"; export interface CteInfo { @@ -19,7 +19,6 @@ export interface CteInfo { withClauseStart: number; // Start position of the WITH clause } -@provideSingleton(CteCodeLensProvider) export class CteCodeLensProvider implements CodeLensProvider, Disposable { private disposables: Disposable[] = []; @@ -37,6 +36,7 @@ export class CteCodeLensProvider implements CodeLensProvider, Disposable { private static readonly MAX_COLUMN_LIST_LENGTH = 1000; constructor( + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, private altimate: AltimateRequest, ) {} diff --git a/src/code_lens_provider/documentationCodeLensProvider.ts b/src/code_lens_provider/documentationCodeLensProvider.ts index d0f32eb3e..6e3639081 100644 --- a/src/code_lens_provider/documentationCodeLensProvider.ts +++ b/src/code_lens_provider/documentationCodeLensProvider.ts @@ -7,14 +7,11 @@ import { ProviderResult, Range, TextDocument, - window, workspace, } from "vscode"; -import { provideSingleton } from "../utils"; import { CST, LineCounter, Parser } from "yaml"; import path = require("path"); -@provideSingleton(DocumentationCodeLensProvider) export class DocumentationCodeLensProvider implements CodeLensProvider { private _onDidChangeCodeLenses: EventEmitter = new EventEmitter(); public readonly onDidChangeCodeLenses: Event = diff --git a/src/code_lens_provider/index.ts b/src/code_lens_provider/index.ts index 5e7f5efe0..981109f56 100644 --- a/src/code_lens_provider/index.ts +++ b/src/code_lens_provider/index.ts @@ -1,13 +1,11 @@ import { Disposable, languages } from "vscode"; import { DBTPowerUserExtension } from "../dbtPowerUserExtension"; -import { provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { CteCodeLensProvider } from "./cteCodeLensProvider"; +import { DocumentationCodeLensProvider } from "./documentationCodeLensProvider"; import { SourceModelCreationCodeLensProvider } from "./sourceModelCreationCodeLensProvider"; import { VirtualSqlCodeLensProvider } from "./virtualSqlCodeLensProvider"; 
-import { DocumentationCodeLensProvider } from "./documentationCodeLensProvider"; -import { CteCodeLensProvider } from "./cteCodeLensProvider"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -@provideSingleton(CodeLensProviders) export class CodeLensProviders implements Disposable { private disposables: Disposable[] = []; constructor( diff --git a/src/code_lens_provider/sourceModelCreationCodeLensProvider.ts b/src/code_lens_provider/sourceModelCreationCodeLensProvider.ts index 4076e5ae1..7216df8dd 100644 --- a/src/code_lens_provider/sourceModelCreationCodeLensProvider.ts +++ b/src/code_lens_provider/sourceModelCreationCodeLensProvider.ts @@ -9,7 +9,6 @@ import { Uri, } from "vscode"; import { CST, LineCounter, Parser } from "yaml"; -import { provideSingleton } from "../utils"; interface Position { line: number; @@ -24,7 +23,6 @@ export interface GenerateModelFromSourceParams { tableIdentifier?: string; } -@provideSingleton(SourceModelCreationCodeLensProvider) export class SourceModelCreationCodeLensProvider implements CodeLensProvider { private codeLenses: CodeLens[] = []; private _onDidChangeCodeLenses: EventEmitter = new EventEmitter(); diff --git a/src/code_lens_provider/virtualSqlCodeLensProvider.ts b/src/code_lens_provider/virtualSqlCodeLensProvider.ts index 665886a72..98540981e 100644 --- a/src/code_lens_provider/virtualSqlCodeLensProvider.ts +++ b/src/code_lens_provider/virtualSqlCodeLensProvider.ts @@ -1,19 +1,17 @@ +import { NotebookService } from "@lib"; import { CancellationToken, CodeLens, CodeLensProvider, Command, + Disposable, Range, TextDocument, - Disposable, window, } from "vscode"; -import { provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; import { QueryManifestService } from "../services/queryManifestService"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { NotebookService } from "@lib"; -@provideSingleton(VirtualSqlCodeLensProvider) export class VirtualSqlCodeLensProvider implements CodeLensProvider, Disposable { diff --git a/src/commandProcessExecution.ts b/src/commandProcessExecution.ts deleted file mode 100755 index 1c8a5e2f8..000000000 --- a/src/commandProcessExecution.ts +++ /dev/null @@ -1,172 +0,0 @@ -import { spawn } from "child_process"; -import { provide } from "inversify-binding-decorators"; -import { CancellationToken, Disposable } from "vscode"; -import { DBTTerminal } from "./dbt_client/dbtTerminal"; -import { EnvironmentVariables } from "./domain"; - -@provide(CommandProcessExecutionFactory) -export class CommandProcessExecutionFactory { - constructor(private terminal: DBTTerminal) {} - - createCommandProcessExecution({ - command, - args, - stdin, - cwd, - tokens, - envVars, - }: { - command: string; - args?: string[]; - stdin?: string; - cwd?: string; - tokens?: CancellationToken[]; - envVars?: EnvironmentVariables; - }) { - return new CommandProcessExecution( - this.terminal, - command, - args, - stdin, - cwd, - tokens, - envVars, - ); - } -} - -export interface CommandProcessResult { - stdout: string; - stderr: string; - fullOutput: string; -} - -export class CommandProcessExecution { - private disposables: Disposable[] = []; - - constructor( - private terminal: DBTTerminal, - private command: string, - private args?: string[], - private stdin?: string, - private cwd?: string, - private tokens?: CancellationToken[], - private envVars?: EnvironmentVariables, - ) {} - - private spawn() { - const proc = spawn(this.command, this.args, { 
- cwd: this.cwd, - env: this.envVars, - }); - if (this.tokens !== undefined) { - this.tokens.forEach((token) => - this.disposables.push( - token.onCancellationRequested(() => { - proc.kill("SIGTERM"); - }), - ), - ); - } - return proc; - } - - private dispose() { - while (this.disposables.length) { - const x = this.disposables.pop(); - if (x) { - x.dispose(); - } - } - } - - async complete(): Promise { - return new Promise((resolve, reject) => { - this.terminal.debug( - "CommandProcessExecution", - "Going to execute command : " + this.command, - this.args, - ); - const commandProcess = this.spawn(); - let stdoutBuffer = ""; - let stderrBuffer = ""; - let fullOutput = ""; - commandProcess.stdout!.on("data", (chunk) => { - chunk = chunk.toString(); - stdoutBuffer += chunk; - fullOutput += chunk; - }); - commandProcess.stderr!.on("data", (chunk) => { - chunk = chunk.toString(); - stderrBuffer += chunk; - fullOutput += chunk; - }); - - commandProcess.once("close", () => { - this.terminal.debug( - "CommandProcessExecution", - "Return value from command: " + this.command, - this.args, - fullOutput, - ); - resolve({ stdout: stdoutBuffer, stderr: stderrBuffer, fullOutput }); - }); - - commandProcess.once("error", (error) => { - this.terminal.error( - "CommandProcessExecutionError", - "Command errored: " + this.command, - error, - true, - this.command, - this.args, - error, - ); - reject(new Error(`${error}`)); - }); - - if (this.stdin) { - commandProcess.stdin.write(this.stdin); - commandProcess.stdin.end(); - } - }); - } - - async completeWithTerminalOutput(): Promise { - return new Promise((resolve, reject) => { - const commandProcess = this.spawn(); - let stdoutBuffer = ""; - let stderrBuffer = ""; - let fullOutput = ""; - commandProcess.stdout!.on("data", (chunk) => { - const line = `${this.formatText(chunk.toString())}`; - stdoutBuffer += line; - this.terminal.log(line); - fullOutput += line; - }); - commandProcess.stderr!.on("data", (chunk) => { - const line = `${this.formatText(chunk.toString())}`; - stderrBuffer += line; - this.terminal.log(line); - fullOutput += line; - }); - commandProcess.once("close", () => { - resolve({ stdout: stdoutBuffer, stderr: stderrBuffer, fullOutput }); - this.terminal.log(""); - this.dispose(); - }); - commandProcess.once("error", (error) => { - reject(new Error(`Error occurred during process execution: ${error}`)); - }); - - if (this.stdin) { - commandProcess.stdin.write(this.stdin); - commandProcess.stdin.end(); - } - }); - } - - public formatText(text: string) { - return `${text.split(/(\r?\n)+/g).join("\r")}`; - } -} diff --git a/src/commands/altimateScan.ts b/src/commands/altimateScan.ts index 819d597e0..1bffb9f01 100644 --- a/src/commands/altimateScan.ts +++ b/src/commands/altimateScan.ts @@ -1,22 +1,21 @@ -import { ProgressLocation, Uri, commands, window } from "vscode"; +import { DBTTerminal } from "@altimateai/dbt-integration"; +import { inject } from "inversify"; +import { commands, ProgressLocation, Uri, window } from "vscode"; import { AltimateRequest } from "../altimate"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; import { ManifestCacheChangedEvent, ManifestCacheProjectAddedEvent, -} from "../manifest/event/manifestCacheChangedEvent"; +} from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; -import { provideSingleton } from "../utils"; import { InitCatalog } from "./tests/initCatalog"; -import { 
UndocumentedModelColumnTest } from "./tests/undocumentedModelColumnTest"; -import { StaleModelColumnTest } from "./tests/staleModelColumnTest"; import { MissingSchemaTest } from "./tests/missingSchemaTest"; -import { UnmaterializedModelTest } from "./tests/unmaterializedModelTest"; import { ScanContext } from "./tests/scanContext"; +import { StaleModelColumnTest } from "./tests/staleModelColumnTest"; import { AltimateScanStep } from "./tests/step"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; +import { UndocumentedModelColumnTest } from "./tests/undocumentedModelColumnTest"; +import { UnmaterializedModelTest } from "./tests/unmaterializedModelTest"; -@provideSingleton(AltimateScan) export class AltimateScan { private eventMap: Map = new Map(); private offlineAltimateScanSteps: AltimateScanStep[]; @@ -31,6 +30,7 @@ export class AltimateScan { private undocumentedModelColumnTest: UndocumentedModelColumnTest, private unmaterializedModelTest: UnmaterializedModelTest, private staleModelColumnTest: StaleModelColumnTest, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, ) { dbtProjectContainer.onManifestChanged((event) => diff --git a/src/commands/bigQueryCostEstimate.ts b/src/commands/bigQueryCostEstimate.ts index 646c84fc9..9718f0913 100644 --- a/src/commands/bigQueryCostEstimate.ts +++ b/src/commands/bigQueryCostEstimate.ts @@ -1,15 +1,16 @@ +import { DBTTerminal } from "@altimateai/dbt-integration"; +import { inject } from "inversify"; +import { PythonException } from "python-bridge"; import { window } from "vscode"; -import path = require("path"); -import { extendErrorWithSupportLinks, provideSingleton } from "../utils"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; import { TelemetryService } from "../telemetry"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { PythonException } from "python-bridge"; +import { extendErrorWithSupportLinks } from "../utils"; +import path = require("path"); -@provideSingleton(BigQueryCostEstimate) export class BigQueryCostEstimate { constructor( private dbtProjectContainer: DBTProjectContainer, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, private telemetry: TelemetryService, ) {} diff --git a/src/commands/index.ts b/src/commands/index.ts index 6d65bc774..325221eff 100644 --- a/src/commands/index.ts +++ b/src/commands/index.ts @@ -1,57 +1,56 @@ +import { DBTTerminal, RunModelType } from "@altimateai/dbt-integration"; +import { DatapilotNotebookController, OpenNotebookRequest } from "@lib"; +import { existsSync, readFileSync } from "fs"; +import { inject } from "inversify"; import { commands, CommentReply, CommentThread, + DecorationRangeBehavior, Disposable, + env, + extensions, languages, + ProgressLocation, + Range, TextEditor, + TextEditorDecorationType, + Uri, + version, ViewColumn, window, workspace, - version, - extensions, - Uri, - Range, - ProgressLocation, - TextEditorDecorationType, - DecorationRangeBehavior, - env, } from "vscode"; +import { AltimateRequest } from "../altimate"; +import { CteInfo } from "../code_lens_provider/cteCodeLensProvider"; +import { + ConversationCommentThread, + ConversationProvider, +} from "../comment_provider/conversationProvider"; import { SqlPreviewContentProvider } from "../content_provider/sqlPreviewContentProvider"; -import { RunModelType } from "../domain"; +import { DBTClient } from "../dbt_client"; +import { DBTProject } from "../dbt_client/dbtProject"; +import 
{ DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { PythonEnvironment } from "../dbt_client/pythonEnvironment"; +import { NotebookQuickPick } from "../quickpick/notebookQuickPick"; +import { ProjectQuickPickItem } from "../quickpick/projectQuickPick"; +import { DiagnosticsOutputChannel } from "../services/diagnosticsOutputChannel"; +import { QueryManifestService } from "../services/queryManifestService"; +import { SharedStateService } from "../services/sharedStateService"; import { deepEqual, extendErrorWithSupportLinks, getFirstWorkspacePath, getFormattedDateTime, - provideSingleton, } from "../utils"; +import { SQLLineagePanel } from "../webview_provider/sqlLineagePanel"; +import { AltimateScan } from "./altimateScan"; +import { BigQueryCostEstimate } from "./bigQueryCostEstimate"; import { RunModel } from "./runModel"; import { SqlToModel } from "./sqlToModel"; -import { AltimateScan } from "./altimateScan"; -import { WalkthroughCommands } from "./walkthroughCommands"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ProjectQuickPickItem } from "../quickpick/projectQuickPick"; import { ValidateSql } from "./validateSql"; -import { BigQueryCostEstimate } from "./bigQueryCostEstimate"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; -import { SharedStateService } from "../services/sharedStateService"; -import { - ConversationProvider, - ConversationCommentThread, -} from "../comment_provider/conversationProvider"; -import { PythonEnvironment } from "../manifest/pythonEnvironment"; -import { DBTClient } from "../dbt_client"; -import { existsSync, readFileSync } from "fs"; -import { DBTProject } from "../manifest/dbtProject"; -import { SQLLineagePanel } from "../webview_provider/sqlLineagePanel"; -import { QueryManifestService } from "../services/queryManifestService"; -import { AltimateRequest } from "../altimate"; -import { DatapilotNotebookController, OpenNotebookRequest } from "@lib"; -import { NotebookQuickPick } from "../quickpick/notebookQuickPick"; -import { CteInfo } from "../code_lens_provider/cteCodeLensProvider"; +import { WalkthroughCommands } from "./walkthroughCommands"; -@provideSingleton(VSCodeCommands) export class VSCodeCommands implements Disposable { private disposables: Disposable[] = []; @@ -63,9 +62,12 @@ export class VSCodeCommands implements Disposable { private altimateScan: AltimateScan, private walkthroughCommands: WalkthroughCommands, private bigQueryCostEstimate: BigQueryCostEstimate, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, + private diagnosticsOutputChannel: DiagnosticsOutputChannel, private eventEmitterService: SharedStateService, private conversationController: ConversationProvider, + @inject(PythonEnvironment) private pythonEnvironment: PythonEnvironment, private dbtClient: DBTClient, private sqlLineagePanel: SQLLineagePanel, @@ -452,13 +454,12 @@ export class VSCodeCommands implements Disposable { ), commands.registerCommand("dbtPowerUser.diagnostics", async () => { try { - await this.dbtTerminal.show(true); - await new Promise((resolve) => setTimeout(resolve, 1000)); - this.dbtTerminal.logLine("Diagnostics started..."); - this.dbtTerminal.logNewLine(); + this.diagnosticsOutputChannel.show(); + this.diagnosticsOutputChannel.logLine("Diagnostics started..."); + this.diagnosticsOutputChannel.logNewLine(); // Printing env vars - this.dbtTerminal.logBlockWithHeader( + this.diagnosticsOutputChannel.logBlockWithHeader( [ "Printing environment variables...", "* Please remove 
any sensitive information before sending it to us", @@ -467,17 +468,17 @@ export class VSCodeCommands implements Disposable { ([key, value]) => `${key}=${value}`, ), ); - this.dbtTerminal.logNewLine(); + this.diagnosticsOutputChannel.logNewLine(); - // Printing env vars - this.dbtTerminal.logBlockWithHeader( + // Printing python paths + this.diagnosticsOutputChannel.logBlockWithHeader( [ "Printing all python paths...", "* Please remove any sensitive information before sending it to us", ], this.pythonEnvironment.allPythonPaths.map(({ path }) => path), ); - this.dbtTerminal.logNewLine(); + this.diagnosticsOutputChannel.logNewLine(); // Printing extension settings const dbtSettings = workspace.getConfiguration().inspect("dbt"); @@ -489,7 +490,7 @@ export class VSCodeCommands implements Disposable { ...Object.keys(defaultValue), ...Object.keys(workspaceValue), ]; - this.dbtTerminal.logBlockWithHeader( + this.diagnosticsOutputChannel.logBlockWithHeader( [ "Printing extension settings...", "* Please remove any sensitive information before sending it to us", @@ -512,7 +513,7 @@ export class VSCodeCommands implements Disposable { return `${key}=${valueText}\t\t${overridenText}`; }), ); - this.dbtTerminal.logNewLine(); + this.diagnosticsOutputChannel.logNewLine(); // Printing extension and setup info const dbtIntegrationMode = workspace @@ -522,7 +523,7 @@ export class VSCodeCommands implements Disposable { .getConfiguration("dbt") .get("allowListFolders", []); const apiConnectivity = await this.altimate.checkApiConnectivity(); - this.dbtTerminal.logBlock([ + this.diagnosticsOutputChannel.logBlock([ `Python Path=${this.pythonEnvironment.pythonPath}`, `VSCode version=${version}`, `Extension version=${ @@ -537,81 +538,93 @@ export class VSCodeCommands implements Disposable { : "", `AllowList Folders=${allowListFolders}`, ]); - this.dbtTerminal.logNewLine(); + this.diagnosticsOutputChannel.logNewLine(); if (!this.dbtClient.pythonInstalled) { - this.dbtTerminal.logLine("Python is not installed"); - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logLine("Python is not installed"); + this.diagnosticsOutputChannel.logLine( "Can't proceed further without fixing python installation", ); return; } - this.dbtTerminal.logLine("Python is installed"); + this.diagnosticsOutputChannel.logLine("Python is installed"); if (!this.dbtClient.dbtInstalled) { - this.dbtTerminal.logLine("DBT is not installed"); - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logLine("DBT is not installed"); + this.diagnosticsOutputChannel.logLine( "Can't proceed further without fixing dbt installation", ); return; } - this.dbtTerminal.logLine("DBT is installed"); + this.diagnosticsOutputChannel.logLine("DBT is installed"); const dbtWorkspaces = this.dbtProjectContainer.dbtWorkspaceFolders; - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logLine( `Number of workspaces=${dbtWorkspaces.length}`, ); for (const w of dbtWorkspaces) { - this.dbtTerminal.logHorizontalRule(); - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logHorizontalRule(); + this.diagnosticsOutputChannel.logLine( `Workspace Path=${w.workspaceFolder.uri.fsPath}`, ); - this.dbtTerminal.logLine(`Adapters=${w.getAdapters()}`); - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logLine( + `Adapters=${w.getAdapters()}`, + ); + this.diagnosticsOutputChannel.logLine( `AllowList Folders=${w.getAllowListFolders()}`, ); w.projectDiscoveryDiagnostics.forEach((uri, diagnostics) => { - this.dbtTerminal.logLine(`Problems for 
${uri.fsPath}`); + this.diagnosticsOutputChannel.logLine( + `Problems for ${uri.fsPath}`, + ); diagnostics.forEach((d) => { - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logLine( `source=${d.source}\tmessage=${d.message}`, ); }); }); - this.dbtTerminal.logHorizontalRule(); + this.diagnosticsOutputChannel.logHorizontalRule(); } const projects = this.dbtProjectContainer.getProjects(); - this.dbtTerminal.logLine(`Number of projects=${projects.length}`); + this.diagnosticsOutputChannel.logLine( + `Number of projects=${projects.length}`, + ); if (projects.length === 0) { - this.dbtTerminal.logLine("No project detected"); - this.dbtTerminal.logLine("Can't proceed further without project"); + this.diagnosticsOutputChannel.logLine("No project detected"); + this.diagnosticsOutputChannel.logLine( + "Can't proceed further without project", + ); return; } - this.dbtTerminal.logNewLine(); + this.diagnosticsOutputChannel.logNewLine(); for (const project of projects) { try { - this.dbtTerminal.logHorizontalRule(); - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logHorizontalRule(); + this.diagnosticsOutputChannel.logLine( `Printing information for ${project.getProjectName()}`, ); - this.dbtTerminal.logHorizontalRule(); + this.diagnosticsOutputChannel.logHorizontalRule(); await this.printProjectInfo(project); } catch (e) { - this.dbtTerminal.logNewLine(); - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logNewLine(); + this.diagnosticsOutputChannel.logLine( "Failed to print all the info for the project...", ); - this.dbtTerminal.logLine(`Error=${e}`); + this.diagnosticsOutputChannel.logLine(`Error=${e}`); } finally { - this.dbtTerminal.logHorizontalRule(); + this.diagnosticsOutputChannel.logHorizontalRule(); } } - this.dbtTerminal.logNewLine(); - this.dbtTerminal.logLine("Diagnostics completed successfully..."); + this.diagnosticsOutputChannel.logNewLine(); + this.diagnosticsOutputChannel.logLine( + "Diagnostics completed successfully...", + ); } catch (e) { - this.dbtTerminal.logNewLine(); - this.dbtTerminal.logLine("Diagnostics ended with error..."); - this.dbtTerminal.logLine(`Error=${e}`); + this.diagnosticsOutputChannel.logNewLine(); + this.diagnosticsOutputChannel.logLine( + "Diagnostics ended with error...", + ); + this.diagnosticsOutputChannel.logLine(`Error=${e}`); } }), commands.registerCommand( @@ -873,23 +886,29 @@ export class VSCodeCommands implements Disposable { } private async printProjectInfo(project: DBTProject) { - this.dbtTerminal.logLine(`Project Name=${project.getProjectName()}`); - this.dbtTerminal.logLine(`Adapter Type=${project.getAdapterType()}`); + this.diagnosticsOutputChannel.logLine( + `Project Name=${project.getProjectName()}`, + ); + this.diagnosticsOutputChannel.logLine( + `Adapter Type=${project.getAdapterType()}`, + ); const dbtVersion = project.getDBTVersion(); if (!dbtVersion) { - this.dbtTerminal.logLine("DBT is not initialized properly"); + this.diagnosticsOutputChannel.logLine("DBT is not initialized properly"); } else { - this.dbtTerminal.logLine(`DBT version=${dbtVersion.join(".")}`); + this.diagnosticsOutputChannel.logLine( + `DBT version=${dbtVersion.join(".")}`, + ); } if (!project.getPythonBridgeStatus()) { - this.dbtTerminal.logLine("Python bridge is not connected"); + this.diagnosticsOutputChannel.logLine("Python bridge is not connected"); } else { - this.dbtTerminal.logLine("Python bridge is connected"); + this.diagnosticsOutputChannel.logLine("Python bridge is connected"); } - this.dbtTerminal.logNewLine(); + 
this.diagnosticsOutputChannel.logNewLine(); const paths = [ { @@ -919,7 +938,7 @@ export class VSCodeCommands implements Disposable { for (const p of paths) { if (!p.path) { - this.dbtTerminal.logLine(`${p.pathType} path not found`); + this.diagnosticsOutputChannel.logLine(`${p.pathType} path not found`); continue; } let line = `${p.pathType} path=${p.path}\t\t`; @@ -928,29 +947,29 @@ export class VSCodeCommands implements Disposable { } else { line += "File exists at location"; } - this.dbtTerminal.logLine(line); + this.diagnosticsOutputChannel.logLine(line); } const dbtProjectFilePath = project.getDBTProjectFilePath(); if (existsSync(dbtProjectFilePath)) { - this.dbtTerminal.logNewLine(); - this.dbtTerminal.logNewLine(); - this.dbtTerminal.logLine("dbt_project.yml"); - this.dbtTerminal.logHorizontalRule(); + this.diagnosticsOutputChannel.logNewLine(); + this.diagnosticsOutputChannel.logNewLine(); + this.diagnosticsOutputChannel.logLine("dbt_project.yml"); + this.diagnosticsOutputChannel.logHorizontalRule(); const fileContent = readFileSync(dbtProjectFilePath, "utf8"); - this.dbtTerminal.logLine(fileContent.replace(/\n/g, "\r\n")); - this.dbtTerminal.logHorizontalRule(); + this.diagnosticsOutputChannel.logLine(fileContent.replace(/\n/g, "\r\n")); + this.diagnosticsOutputChannel.logHorizontalRule(); } - this.dbtTerminal.logNewLine(); + this.diagnosticsOutputChannel.logNewLine(); const diagnostics = project.getAllDiagnostic(); - this.dbtTerminal.logLine( + this.diagnosticsOutputChannel.logLine( `Number of diagnostics issues=${diagnostics.length}`, ); for (const d of diagnostics) { - this.dbtTerminal.logLine(d.message); + this.diagnosticsOutputChannel.logLine(d.message); } - await project.debug(); + await project.debug(false); } private runSelectedQuery(uri: Uri, range: Range): void { diff --git a/src/commands/runModel.ts b/src/commands/runModel.ts index 3ca6fd736..b22bd3b39 100644 --- a/src/commands/runModel.ts +++ b/src/commands/runModel.ts @@ -1,12 +1,11 @@ import path = require("path"); +import { RunModelType } from "@altimateai/dbt-integration"; import { Uri, window } from "vscode"; import { GenerateModelFromSourceParams } from "../code_lens_provider/sourceModelCreationCodeLensProvider"; -import { RunModelType } from "../domain"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; import { NodeTreeItem } from "../treeview_provider/modelTreeviewProvider"; -import { extendErrorWithSupportLinks, provideSingleton } from "../utils"; +import { extendErrorWithSupportLinks } from "../utils"; -@provideSingleton(RunModel) export class RunModel { constructor(private dbtProjectContainer: DBTProjectContainer) {} diff --git a/src/commands/sqlToModel.ts b/src/commands/sqlToModel.ts index a6131fcc1..df2aed742 100644 --- a/src/commands/sqlToModel.ts +++ b/src/commands/sqlToModel.ts @@ -1,23 +1,26 @@ +import { DBTTerminal } from "@altimateai/dbt-integration"; +import { inject } from "inversify"; +import * as path from "path"; +import { Position, ProgressLocation, Range, window } from "vscode"; import { AltimateRequest } from "../altimate"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; import { ManifestCacheChangedEvent, ManifestCacheProjectAddedEvent, -} from "../manifest/event/manifestCacheChangedEvent"; +} from 
"../dbt_client/event/manifestCacheChangedEvent"; +import { AltimateAuthService } from "../services/altimateAuthService"; import { TelemetryService } from "../telemetry"; -import { extendErrorWithSupportLinks, provideSingleton } from "../utils"; -import { Position, ProgressLocation, Range, window } from "vscode"; -import * as path from "path"; +import { extendErrorWithSupportLinks } from "../utils"; -@provideSingleton(SqlToModel) export class SqlToModel { private eventMap: Map = new Map(); constructor( private dbtProjectContainer: DBTProjectContainer, private telemetry: TelemetryService, private altimate: AltimateRequest, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, + private altimateAuthService: AltimateAuthService, ) { dbtProjectContainer.onManifestChanged((event) => this.onManifestCacheChanged(event), @@ -34,7 +37,7 @@ export class SqlToModel { } async getModelFromSql() { - if (!this.altimate.handlePreviewFeatures()) { + if (!this.altimateAuthService.handlePreviewFeatures()) { return; } this.telemetry.sendTelemetryEvent("sqlToModel"); diff --git a/src/commands/tests/initCatalog.ts b/src/commands/tests/initCatalog.ts index ebdbb25cc..193726848 100644 --- a/src/commands/tests/initCatalog.ts +++ b/src/commands/tests/initCatalog.ts @@ -1,9 +1,7 @@ -import { Catalog } from "../../dbt_client/dbtIntegration"; -import { provideSingleton } from "../../utils"; +import { Catalog } from "@altimateai/dbt-integration"; import { ScanContext } from "./scanContext"; import { AltimateScanStep } from "./step"; -@provideSingleton(InitCatalog) export class InitCatalog implements AltimateScanStep { public async run(scanContext: ScanContext) { const project = scanContext.project; diff --git a/src/commands/tests/missingSchemaTest.ts b/src/commands/tests/missingSchemaTest.ts index 52ec7605c..e25965d94 100644 --- a/src/commands/tests/missingSchemaTest.ts +++ b/src/commands/tests/missingSchemaTest.ts @@ -1,10 +1,8 @@ +import { RESOURCE_TYPE_MODEL } from "@altimateai/dbt-integration"; import { Diagnostic, DiagnosticSeverity, Range } from "vscode"; import { ScanContext } from "./scanContext"; import { AltimateScanStep } from "./step"; -import { provideSingleton } from "../../utils"; -import { DBTProject } from "../../manifest/dbtProject"; -@provideSingleton(MissingSchemaTest) export class MissingSchemaTest implements AltimateScanStep { public async run(scanContext: ScanContext) { const { @@ -21,7 +19,7 @@ export class MissingSchemaTest implements AltimateScanStep { // blacklisting node types.. should we instead whitelist just models and sources? 
if ( // TODO - need to filter out only models here but the resource type isnt available - !value.uniqueId.startsWith(DBTProject.RESOURCE_TYPE_MODEL) || + !value.uniqueId.startsWith(RESOURCE_TYPE_MODEL) || value.config.materialized === "seed" || value.config.materialized === "ephemeral" ) { diff --git a/src/commands/tests/scanContext.ts b/src/commands/tests/scanContext.ts index e9930ea5a..21d3d2dbb 100644 --- a/src/commands/tests/scanContext.ts +++ b/src/commands/tests/scanContext.ts @@ -1,7 +1,7 @@ -import { Catalog } from "../../dbt_client/dbtIntegration"; -import { DBTProject } from "../../manifest/dbtProject"; -import { ManifestCacheProjectAddedEvent } from "../../manifest/event/manifestCacheChangedEvent"; +import { Catalog } from "@altimateai/dbt-integration"; import { Diagnostic } from "vscode"; +import { DBTProject } from "../../dbt_client/dbtProject"; +import { ManifestCacheProjectAddedEvent } from "../../dbt_client/event/manifestCacheChangedEvent"; export interface AltimateCatalog { [projectName: string]: { [key: string]: Catalog }; diff --git a/src/commands/tests/staleModelColumnTest.ts b/src/commands/tests/staleModelColumnTest.ts index 525777187..f22c7b321 100644 --- a/src/commands/tests/staleModelColumnTest.ts +++ b/src/commands/tests/staleModelColumnTest.ts @@ -1,15 +1,10 @@ +import { createFullPathForNode } from "@altimateai/dbt-integration"; +import { readFileSync } from "fs"; import { Diagnostic, DiagnosticSeverity, Range, Uri } from "vscode"; +import { getColumnNameByCase, removeProtocol } from "../../utils"; import { ScanContext } from "./scanContext"; import { AltimateScanStep } from "./step"; -import { readFileSync } from "fs"; -import { - getColumnNameByCase, - provideSingleton, - removeProtocol, -} from "../../utils"; -import { createFullPathForNode } from "../../manifest/parsers"; -@provideSingleton(StaleModelColumnTest) export class StaleModelColumnTest implements AltimateScanStep { private getTextLocation( modelname: string, diff --git a/src/commands/tests/undocumentedModelColumnTest.ts b/src/commands/tests/undocumentedModelColumnTest.ts index 5a5417760..a34cd4376 100644 --- a/src/commands/tests/undocumentedModelColumnTest.ts +++ b/src/commands/tests/undocumentedModelColumnTest.ts @@ -1,9 +1,8 @@ import { Diagnostic, DiagnosticSeverity, Range } from "vscode"; +import { getColumnNameByCase } from "../../utils"; import { ScanContext } from "./scanContext"; import { AltimateScanStep } from "./step"; -import { getColumnNameByCase, provideSingleton } from "../../utils"; -@provideSingleton(UndocumentedModelColumnTest) export class UndocumentedModelColumnTest implements AltimateScanStep { public async run(scanContext: ScanContext) { const { diff --git a/src/commands/tests/unmaterializedModelTest.ts b/src/commands/tests/unmaterializedModelTest.ts index 3a8fa00c6..be5a9d035 100644 --- a/src/commands/tests/unmaterializedModelTest.ts +++ b/src/commands/tests/unmaterializedModelTest.ts @@ -1,9 +1,7 @@ import { Diagnostic, DiagnosticSeverity, Range } from "vscode"; import { ScanContext } from "./scanContext"; import { AltimateScanStep } from "./step"; -import { provideSingleton } from "../../utils"; -@provideSingleton(UnmaterializedModelTest) export class UnmaterializedModelTest implements AltimateScanStep { public async run(scanContext: ScanContext) { const { diff --git a/src/commands/validateSql.ts b/src/commands/validateSql.ts index 544f447fd..18923c4f0 100644 --- a/src/commands/validateSql.ts +++ b/src/commands/validateSql.ts @@ -1,36 +1,32 @@ +import { DBTTerminal } from 
"@altimateai/dbt-integration"; +import { inject } from "inversify"; import { basename } from "path"; -import { AltimateRequest, ModelNode } from "../altimate"; -import { ColumnMetaData, NodeMetaData, SourceTable } from "../domain"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { - ManifestCacheChangedEvent, - ManifestCacheProjectAddedEvent, -} from "../manifest/event/manifestCacheChangedEvent"; -import { TelemetryService } from "../telemetry"; -import { extendErrorWithSupportLinks, provideSingleton } from "../utils"; +import { PythonException } from "python-bridge"; import { CancellationToken, - DiagnosticCollection, - ProgressLocation, - Uri, - ViewColumn, - window, -} from "vscode"; -import { DBTProject } from "../manifest/dbtProject"; -import { commands, Diagnostic, + DiagnosticCollection, DiagnosticSeverity, languages, Position, + ProgressLocation, Range, + Uri, + ViewColumn, + window, workspace, } from "vscode"; +import { AltimateRequest, ModelNode } from "../altimate"; import { SqlPreviewContentProvider } from "../content_provider/sqlPreviewContentProvider"; -import { PythonException } from "python-bridge"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { + ManifestCacheChangedEvent, + ManifestCacheProjectAddedEvent, +} from "../dbt_client/event/manifestCacheChangedEvent"; +import { TelemetryService } from "../telemetry"; +import { extendErrorWithSupportLinks } from "../utils"; -@provideSingleton(ValidateSql) export class ValidateSql { private eventMap: Map = new Map(); private diagnosticsCollection: DiagnosticCollection; @@ -38,6 +34,7 @@ export class ValidateSql { private dbtProjectContainer: DBTProjectContainer, private telemetry: TelemetryService, private altimate: AltimateRequest, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, ) { dbtProjectContainer.onManifestChanged((event) => @@ -119,6 +116,7 @@ export class ValidateSql { let relationsWithoutColumns: string[] = []; let compiledQuery: string | undefined; let cancellationToken: CancellationToken | undefined; + let abortController: AbortController | undefined; await window.withProgress( { location: ProgressLocation.Notification, @@ -128,6 +126,8 @@ export class ValidateSql { async (_, token) => { try { cancellationToken = token; + abortController = new AbortController(); + token.onCancellationRequested(() => abortController!.abort()); const fileContentBytes = await workspace.fs.readFile(currentFilePath); if (cancellationToken.isCancellationRequested) { return; @@ -156,9 +156,8 @@ export class ValidateSql { mappedNode, relationsWithoutColumns: _relationsWithoutColumns, } = await project.getNodesWithDBColumns( - event, modelsToFetch, - cancellationToken, + abortController!.signal, ); parentModels.push(...modelsToFetch.map((n) => mappedNode[n])); relationsWithoutColumns = _relationsWithoutColumns; diff --git a/src/commands/walkthroughCommands.ts b/src/commands/walkthroughCommands.ts index 6751b3226..d6184dd82 100644 --- a/src/commands/walkthroughCommands.ts +++ b/src/commands/walkthroughCommands.ts @@ -1,17 +1,20 @@ import { - window, - QuickPickItem, - ProgressLocation, + CommandProcessExecutionFactory, + DBTTerminal, +} from "@altimateai/dbt-integration"; +import { inject } from "inversify"; +import { commands, + ProgressLocation, + QuickPickItem, + window, workspace, } from "vscode"; -import { getFirstWorkspacePath, provideSingleton } from "../utils"; -import { DBTProjectContainer } from 
"../manifest/dbtProjectContainer"; -import { TelemetryService } from "../telemetry"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { PythonEnvironment } from "../dbt_client/pythonEnvironment"; import { ProjectQuickPickItem } from "../quickpick/projectQuickPick"; -import { CommandProcessExecutionFactory } from "../commandProcessExecution"; -import { PythonEnvironment } from "../manifest/pythonEnvironment"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; +import { TelemetryService } from "../telemetry"; +import { getFirstWorkspacePath } from "../utils"; enum PromptAnswer { YES = "Yes", @@ -24,13 +27,14 @@ enum DbtInstallationPromptAnswer { INSTALL_FUSION = "Install dbt fusion", } -@provideSingleton(WalkthroughCommands) export class WalkthroughCommands { constructor( private dbtProjectContainer: DBTProjectContainer, private telemetry: TelemetryService, private commandProcessExecutionFactory: CommandProcessExecutionFactory, + @inject(PythonEnvironment) private pythonEnvironment: PythonEnvironment, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, ) {} @@ -68,8 +72,8 @@ export class WalkthroughCommands { return; } const runModelOutput = await project.debug(); - if (runModelOutput.includes("ERROR")) { - throw new Error(runModelOutput); + if (runModelOutput.fullOutput.includes("ERROR")) { + throw new Error(runModelOutput.fullOutput); } } catch (err) { this.dbtTerminal.error( diff --git a/src/comment_provider/conversationProvider.ts b/src/comment_provider/conversationProvider.ts index c866a11cb..2e6608259 100644 --- a/src/comment_provider/conversationProvider.ts +++ b/src/comment_provider/conversationProvider.ts @@ -1,32 +1,36 @@ +import { + DBTTerminal, + RESOURCE_TYPE_MACRO, + RESOURCE_TYPE_TEST, +} from "@altimateai/dbt-integration"; +import { inject } from "inversify"; import { CancellationToken, + commands, Comment, CommentAuthorInformation, CommentMode, CommentReply, + comments, CommentThread, CommentThreadState, Disposable, + env, MarkdownString, Range, TextDocument, Uri, - commands, - comments, - env, window, workspace, } from "vscode"; -import { extendErrorWithSupportLinks, provideSingleton } from "../utils"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; -import path = require("path"); +import { Conversation, ConversationGroup, SharedDoc } from "../altimate"; import { ConversationService } from "../services/conversationService"; +import { QueryManifestService } from "../services/queryManifestService"; import { SharedStateService } from "../services/sharedStateService"; import { UsersService } from "../services/usersService"; -import { QueryManifestService } from "../services/queryManifestService"; -import { DBTProject } from "../manifest/dbtProject"; -import { SharedDoc, ConversationGroup, Conversation } from "../altimate"; import { TelemetryService } from "../telemetry"; +import { extendErrorWithSupportLinks } from "../utils"; +import path = require("path"); // Extends vscode commentthread and add extra fields for reference export interface ConversationCommentThread extends CommentThread { @@ -54,7 +58,6 @@ export class ConversationComment implements Comment { } const ALLOWED_FILE_EXTENSIONS = [".sql"]; -@provideSingleton(ConversationProvider) export class ConversationProvider implements Disposable { private disposables: Disposable[] = []; private commentController; @@ -69,6 +72,7 @@ export class ConversationProvider implements Disposable { constructor( private conversationService: ConversationService, private 
usersService: UsersService, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, private emitterService: SharedStateService, private queryManifestService: QueryManifestService, @@ -447,7 +451,7 @@ export class ConversationProvider implements Disposable { // For macro if (macroNode) { return { - resource_type: DBTProject.RESOURCE_TYPE_MACRO, + resource_type: RESOURCE_TYPE_MACRO, uniqueId: macroNode.uniqueId, }; } @@ -456,7 +460,7 @@ export class ConversationProvider implements Disposable { // For tests if (testNode) { return { - resource_type: DBTProject.RESOURCE_TYPE_TEST, + resource_type: RESOURCE_TYPE_TEST, uniqueId: testNode.uniqueId, }; } diff --git a/src/comment_provider/index.ts b/src/comment_provider/index.ts index d7e692c09..21f27e3f4 100644 --- a/src/comment_provider/index.ts +++ b/src/comment_provider/index.ts @@ -1,9 +1,6 @@ -import { Disposable, languages } from "vscode"; -import { DBTPowerUserExtension } from "../dbtPowerUserExtension"; -import { provideSingleton } from "../utils"; +import { Disposable } from "vscode"; import { ConversationProvider } from "./conversationProvider"; -@provideSingleton(CommentProviders) export class CommentProviders implements Disposable { private disposables: Disposable[] = []; constructor(private conversationProvider: ConversationProvider) { diff --git a/src/constants.ts b/src/constants.ts deleted file mode 100644 index 927854390..000000000 --- a/src/constants.ts +++ /dev/null @@ -1,5 +0,0 @@ -export enum ManifestPathType { - EMPTY = "", - LOCAL = "local", - REMOTE = "remote", -} diff --git a/src/content_provider/index.ts b/src/content_provider/index.ts index 4bb332851..bd86b2436 100644 --- a/src/content_provider/index.ts +++ b/src/content_provider/index.ts @@ -1,8 +1,6 @@ import { Disposable, workspace } from "vscode"; -import { provideSingleton } from "../utils"; import { SqlPreviewContentProvider } from "./sqlPreviewContentProvider"; -@provideSingleton(ContentProviders) export class ContentProviders implements Disposable { private disposables: Disposable[] = []; diff --git a/src/content_provider/sqlPreviewContentProvider.ts b/src/content_provider/sqlPreviewContentProvider.ts index a4db912a0..65c3b4701 100644 --- a/src/content_provider/sqlPreviewContentProvider.ts +++ b/src/content_provider/sqlPreviewContentProvider.ts @@ -4,22 +4,18 @@ import { Event, EventEmitter, FileSystemWatcher, + ProgressLocation, RelativePattern, TextDocumentContentProvider, Uri, - workspace, window, - ProgressLocation, + workspace, } from "vscode"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { debounce, provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; import { TelemetryService } from "../telemetry"; -import { DeferToProdService } from "../services/deferToProdService"; -import { AltimateRequest } from "../altimate"; -import { ManifestPathType } from "../constants"; +import { debounce } from "../utils"; import path = require("path"); -@provideSingleton(SqlPreviewContentProvider) export class SqlPreviewContentProvider implements TextDocumentContentProvider, Disposable { @@ -32,8 +28,6 @@ export class SqlPreviewContentProvider constructor( private dbtProjectContainer: DBTProjectContainer, - private deferToProdService: DeferToProdService, - private altimateRequest: AltimateRequest, private telemetry: TelemetryService, ) { this.subscriptions = workspace.onDidCloseTextDocument((compilationDoc) => { @@ -93,21 +87,7 @@ export class SqlPreviewContentProvider } 
this.telemetry.sendTelemetryEvent("requestCompilation"); await project.refreshProjectConfig(); - const result = await project.unsafeCompileQuery(query, modelName); - const { manifestPathType } = - this.deferToProdService.getDeferConfigByProjectRoot( - project.projectRoot.fsPath, - ); - const dbtIntegrationMode = workspace - .getConfiguration("dbt") - .get("dbtIntegration", "core"); - if ( - dbtIntegrationMode.startsWith("core") && - manifestPathType === ManifestPathType.REMOTE - ) { - this.altimateRequest.sendDeferToProdEvent(ManifestPathType.REMOTE); - } - return result; + return await project.unsafeCompileQuery(query, modelName); } catch (error: any) { const errorMessage = (error as Error).message; window.showErrorMessage(`Error while compiling: ${errorMessage}`); diff --git a/src/dbtPowerUserExtension.ts b/src/dbtPowerUserExtension.ts index 1027c09c1..974f92571 100644 --- a/src/dbtPowerUserExtension.ts +++ b/src/dbtPowerUserExtension.ts @@ -1,29 +1,27 @@ -import { Disposable, ExtensionContext, commands, workspace } from "vscode"; +import { NotebookProviders } from "@lib"; +import { commands, Disposable, ExtensionContext, workspace } from "vscode"; import { AutocompletionProviders } from "./autocompletion_provider"; import { CodeLensProviders } from "./code_lens_provider"; import { VSCodeCommands } from "./commands"; +import { CommentProviders } from "./comment_provider"; import { ContentProviders } from "./content_provider"; +import { DBTProjectContainer } from "./dbt_client/dbtProjectContainer"; import { DefinitionProviders } from "./definition_provider"; import { DocumentFormattingEditProviders } from "./document_formatting_edit_provider"; -import { DBTProjectContainer } from "./manifest/dbtProjectContainer"; -import { StatusBars } from "./statusbar"; -import { TreeviewProviders } from "./treeview_provider"; -import { provideSingleton } from "./utils"; -import { WebviewViewProviders } from "./webview_provider"; -import { TelemetryService } from "./telemetry"; import { HoverProviders } from "./hover_provider"; +import { DbtPowerUserMcpServer } from "./mcp"; import { DbtPowerUserActionsCenter } from "./quickpick"; +import { StatusBars } from "./statusbar"; +import { TelemetryService } from "./telemetry"; +import { TreeviewProviders } from "./treeview_provider"; import { ValidationProvider } from "./validation_provider"; -import { CommentProviders } from "./comment_provider"; -import { NotebookProviders } from "@lib"; -import { DbtPowerUserMcpServer } from "./mcp"; +import { WebviewViewProviders } from "./webview_provider"; enum PromptAnswer { YES = "Yes", NO = "No", } -@provideSingleton(DBTPowerUserExtension) export class DBTPowerUserExtension implements Disposable { static DBT_SQL_SELECTOR = [ { language: "jinja-sql", scheme: "file" }, diff --git a/src/dbt_client/datapilot.ts b/src/dbt_client/datapilot.ts index 1ab6dcd54..b9183d4fb 100644 --- a/src/dbt_client/datapilot.ts +++ b/src/dbt_client/datapilot.ts @@ -1,15 +1,21 @@ -import { getFirstWorkspacePath, provideSingleton } from "../utils"; -import { PythonEnvironment } from "../manifest/pythonEnvironment"; -import { CommandProcessExecutionFactory } from "../commandProcessExecution"; -import { DBTTerminal } from "./dbtTerminal"; +import { + CommandProcessExecutionFactory, + DBTConfiguration, + DBTTerminal, +} from "@altimateai/dbt-integration"; +import { inject } from "inversify"; +import { PythonEnvironment } from "./pythonEnvironment"; -@provideSingleton(AltimateDatapilot) export class AltimateDatapilot { private packageName = 
"altimate-datapilot-cli"; constructor( + @inject(PythonEnvironment) private pythonEnvironment: PythonEnvironment, private commandProcessExecutionFactory: CommandProcessExecutionFactory, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, + @inject("DBTConfiguration") + private dbtConfiguration: DBTConfiguration, ) {} async checkIfAltimateDatapilotInstalled(): Promise { @@ -17,7 +23,7 @@ export class AltimateDatapilot { this.commandProcessExecutionFactory.createCommandProcessExecution({ command: this.pythonEnvironment.pythonPath, args: ["-c", "import datapilot;print(datapilot.__version__)"], - cwd: getFirstWorkspacePath(), + cwd: this.dbtConfiguration.getWorkingDirectory(), envVars: this.pythonEnvironment.environmentVariables, }); const { stdout, stderr } = await process.complete(); @@ -42,7 +48,7 @@ export class AltimateDatapilot { "install", `${this.packageName}==${datapilotVersion}`, ], - cwd: getFirstWorkspacePath(), + cwd: this.dbtConfiguration.getWorkingDirectory(), envVars: this.pythonEnvironment.environmentVariables, }) .completeWithTerminalOutput(); diff --git a/src/dbt_client/dbtCloudIntegration.ts b/src/dbt_client/dbtCloudIntegration.ts deleted file mode 100644 index 1add16aae..000000000 --- a/src/dbt_client/dbtCloudIntegration.ts +++ /dev/null @@ -1,1243 +0,0 @@ -import { - workspace, - Uri, - languages, - Disposable, - Range, - window, - CancellationTokenSource, - Diagnostic, - DiagnosticCollection, - DiagnosticSeverity, - CancellationToken, -} from "vscode"; -import { provideSingleton } from "../utils"; -import { - Catalog, - DBColumn, - DBTNode, - DBTCommand, - DBTCommandExecutionInfrastructure, - DBTCommandExecutionStrategy, - DBTCommandFactory, - DBTDetection, - DBTProjectDetection, - DBTProjectIntegration, - QueryExecution, - HealthcheckArgs, -} from "./dbtIntegration"; -import { CommandProcessExecutionFactory } from "../commandProcessExecution"; -import { PythonBridge } from "python-bridge"; -import { join, dirname } from "path"; -import { AltimateRequest, ValidateSqlParseErrorResponse } from "../altimate"; -import path = require("path"); -import { DBTProject } from "../manifest/dbtProject"; -import { TelemetryService } from "../telemetry"; -import { DBTTerminal } from "./dbtTerminal"; -import { PythonEnvironment } from "../manifest/pythonEnvironment"; -import { existsSync, readFileSync } from "fs"; -import { ValidationProvider } from "../validation_provider"; -import { DeferToProdService } from "../services/deferToProdService"; -import { ProjectHealthcheck } from "./dbtCoreIntegration"; -import semver = require("semver"); -import { NodeMetaData } from "../domain"; -import * as crypto from "crypto"; -import { parse } from "yaml"; - -export function getDBTPath( - pythonEnvironment: PythonEnvironment, - terminal: DBTTerminal, -): string { - if (pythonEnvironment.pythonPath) { - const allowedDbtPaths = ["dbt", "dbt.exe"]; - const dbtPath = allowedDbtPaths.find((path) => - existsSync(join(dirname(pythonEnvironment.pythonPath), path)), - ); - if (dbtPath) { - const dbtPythonPath = join( - dirname(pythonEnvironment.pythonPath), - dbtPath, - ); - terminal.debug("Found dbt path in Python bin directory:", dbtPythonPath); - return dbtPythonPath; - } - } - terminal.debug("Using default dbt path:", "dbt"); - return "dbt"; -} - -@provideSingleton(DBTCloudDetection) -export class DBTCloudDetection implements DBTDetection { - constructor( - protected commandProcessExecutionFactory: CommandProcessExecutionFactory, - protected pythonEnvironment: PythonEnvironment, - protected 
terminal: DBTTerminal, - ) {} - - async detectDBT(): Promise { - const dbtPath = getDBTPath(this.pythonEnvironment, this.terminal); - try { - this.terminal.debug("DBTCLIDetection", "Detecting dbt cloud cli"); - const checkDBTInstalledProcess = - this.commandProcessExecutionFactory.createCommandProcessExecution({ - command: dbtPath, - args: ["--version"], - cwd: this.getFirstWorkspacePath(), - }); - const { stdout, stderr } = await checkDBTInstalledProcess.complete(); - if (stderr) { - throw new Error(stderr); - } - if (stdout.includes("dbt Cloud CLI")) { - const regex = /dbt Cloud CLI - (\d*\.\d*\.\d*)/gm; - const matches = regex.exec(stdout); - if (matches?.length === 2) { - const minVersion = "0.37.6"; - const currentVersion = matches[1]; - if (semver.lt(currentVersion, minVersion)) { - window.showErrorMessage( - `This version of dbt Cloud is not supported. Please update to a dbt Cloud CLI version higher than ${minVersion}`, - ); - this.terminal.debug( - "DBTCLIDetectionFailed", - "dbt cloud cli was found but version is not supported. Detection command returned : " + - stdout, - ); - return true; - } - } - this.terminal.debug("DBTCLIDetectionSuccess", "dbt cloud cli detected"); - return true; - } else { - this.terminal.debug( - "DBTCLIDetectionFailed", - "dbt cloud cli was not found. Detection command returned : " + - stdout, - ); - } - } catch (error) { - this.terminal.warn( - "DBTCLIDetectionError", - "Detection failed with error : " + (error as Error).message, - ); - } - this.terminal.debug( - "DBTCLIDetectionFailed", - "dbt cloud cli was not found. Detection command returning false", - ); - return false; - } - - private getFirstWorkspacePath(): string { - // If we are executing python via a wrapper like Meltano, - // we need to execute it from a (any) project directory - // By default, Command execution is in an ext dir context - const folders = workspace.workspaceFolders; - if (folders) { - return folders[0].uri.fsPath; - } else { - // TODO: this shouldn't happen but we should make sure this is valid fallback - return Uri.file("./").fsPath; - } - } -} - -@provideSingleton(DBTCloudProjectDetection) -export class DBTCloudProjectDetection implements DBTProjectDetection { - constructor(private altimate: AltimateRequest) {} - - async discoverProjects(projectDirectories: Uri[]): Promise { - this.altimate.handlePreviewFeatures(); - const packagesInstallPaths = projectDirectories.map((projectDirectory) => - path.join(projectDirectory.fsPath, "dbt_packages"), - ); - const filteredProjectFiles = projectDirectories.filter((uri) => { - return !packagesInstallPaths.some((packageInstallPath) => { - return uri.fsPath.startsWith(packageInstallPath!); - }); - }); - if (filteredProjectFiles.length > 20) { - window.showWarningMessage( - `dbt Power User detected ${filteredProjectFiles.length} projects in your work space, this will negatively affect performance.`, - ); - } - return filteredProjectFiles; - } -} - -@provideSingleton(DBTCloudProjectIntegration) -export class DBTCloudProjectIntegration - implements DBTProjectIntegration, Disposable -{ - private static QUEUE_ALL = "all"; - protected targetPath?: string; - private version: number[] | undefined; - protected projectName: string = "unknown_" + crypto.randomUUID(); - private adapterType: string = "unknown"; - protected packagesInstallPath?: string; - protected modelPaths?: string[]; - protected seedPaths?: string[]; - protected macroPaths?: string[]; - private python: PythonBridge; - protected dbtPath: string = "dbt"; - private disposables: 
Disposable[] = []; - protected readonly rebuildManifestDiagnostics = - languages.createDiagnosticCollection("dbt"); - private readonly pythonBridgeDiagnostics = - languages.createDiagnosticCollection("dbt"); - protected rebuildManifestCancellationTokenSource: - | CancellationTokenSource - | undefined; - private pathsInitialized = false; - - constructor( - private executionInfrastructure: DBTCommandExecutionInfrastructure, - protected dbtCommandFactory: DBTCommandFactory, - protected cliDBTCommandExecutionStrategyFactory: ( - path: Uri, - dbtPath: string, - ) => DBTCommandExecutionStrategy, - protected telemetry: TelemetryService, - private pythonEnvironment: PythonEnvironment, - protected terminal: DBTTerminal, - private validationProvider: ValidationProvider, - private deferToProdService: DeferToProdService, - protected projectRoot: Uri, - private altimateRequest: AltimateRequest, - ) { - this.terminal.debug( - "DBTCloudProjectIntegration", - `Registering dbt cloud project at ${this.projectRoot}`, - ); - this.python = this.executionInfrastructure.createPythonBridge( - this.projectRoot.fsPath, - ); - this.executionInfrastructure.createQueue( - DBTCloudProjectIntegration.QUEUE_ALL, - ); - - this.disposables.push( - this.pythonEnvironment.onPythonEnvironmentChanged(() => { - this.python = this.executionInfrastructure.createPythonBridge( - this.projectRoot.fsPath, - ); - this.initializeProject(); - }), - this.rebuildManifestDiagnostics, - this.pythonBridgeDiagnostics, - ); - } - - async refreshProjectConfig(): Promise { - if (!this.pathsInitialized) { - // First time let,s block - await this.initializePaths(); - this.pathsInitialized = true; - } else { - this.initializePaths(); - } - if (!this.version) { - await this.findVersion(); - } - } - - async executeSQL( - query: string, - limit: number, - modelName: string, - ): Promise { - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const showCommand = this.dbtCloudCommand( - new DBTCommand("Running sql...", [ - "show", - "--log-level", - "debug", - "--inline", - query, - "--limit", - limit.toString(), - "--output", - "json", - "--log-format", - "json", - ]), - ); - const cancellationTokenSource = new CancellationTokenSource(); - showCommand.setToken(cancellationTokenSource.token); - return new QueryExecution( - async () => { - cancellationTokenSource.cancel(); - }, - async () => { - const { stdout, stderr } = await showCommand.execute( - cancellationTokenSource.token, - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - const parsedLines = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())); - const previewLine = parsedLines.filter( - (line) => - line.hasOwnProperty("data") && line.data.hasOwnProperty("preview"), - ); - const compiledSqlLines = parsedLines.filter( - (line) => - line.hasOwnProperty("data") && line.data.hasOwnProperty("sql"), - ); - if (previewLine.length === 0) { - throw new Error("Could not find previewLine in " + stdout); - } - const preview = JSON.parse(previewLine[0].data.preview); - if (compiledSqlLines.length === 0) { - throw new Error("Could not find compiledSqlLine in " + stdout); - } - const compiledSql = - compiledSqlLines[compiledSqlLines.length - 1].data.sql; - return { - table: { - column_names: preview.length > 0 ? Object.keys(preview[0]) : [], - column_types: - preview.length > 0 - ? 
Object.keys(preview[0]).map((obj: any) => "string") - : [], - rows: preview.map((obj: any) => Object.values(obj)), - }, - compiled_sql: compiledSql, - raw_sql: query, - modelName, - }; - }, - ); - } - - async initializeProject(): Promise { - try { - await this.python.ex`from dbt_cloud_integration import *`; - await this.python.ex`from dbt_healthcheck import *`; - } catch (error) { - this.terminal.error( - "dbtCloudIntegration", - "Could not initalize Python environemnt", - error, - ); - window.showErrorMessage( - "Error occurred while initializing Python environment: " + error, - ); - } - this.dbtPath = getDBTPath(this.pythonEnvironment, this.terminal); - } - - async setSelectedTarget(targetName: string): Promise { - throw new Error("Method not implemented."); - } - - async getTargetNames(): Promise> { - throw new Error("Method not implemented."); - } - - getSelectedTarget(): string | undefined { - throw new Error("Method not implemented."); - } - - getTargetPath(): string | undefined { - return this.targetPath; - } - - getModelPaths(): string[] | undefined { - return this.modelPaths; - } - - getSeedPaths(): string[] | undefined { - return this.seedPaths; - } - - getMacroPaths(): string[] | undefined { - return this.macroPaths; - } - - getPackageInstallPath(): string | undefined { - return this.packagesInstallPath; - } - - getAdapterType(): string { - return this.adapterType; - } - - getVersion(): number[] { - return this.version || [0, 0, 0]; - } - - getProjectName(): string { - return this.projectName; - } - - getPythonBridgeStatus(): boolean { - return this.python.connected; - } - - // Handled by dbt cloud itself - async cleanupConnections(): Promise {} - - getAllDiagnostic(): Diagnostic[] { - return [ - ...(this.pythonBridgeDiagnostics.get(this.projectRoot) || []), - ...(this.rebuildManifestDiagnostics.get(this.projectRoot) || []), - ]; - } - - async rebuildManifest(): Promise { - // TODO: check whether we should allow parsing for unauthenticated users - // this.throwIfNotAuthenticated(); - if (this.rebuildManifestCancellationTokenSource) { - this.rebuildManifestCancellationTokenSource.cancel(); - this.rebuildManifestCancellationTokenSource = undefined; - } - const command = this.dbtCloudCommand( - this.dbtCommandFactory.createParseCommand(), - ); - command.addArgument("--log-format"); - command.addArgument("json"); - command.downloadArtifacts = true; - this.rebuildManifestCancellationTokenSource = new CancellationTokenSource(); - command.setToken(this.rebuildManifestCancellationTokenSource.token); - - try { - const result = await command.execute(); - const stderr = result.stderr; - // sending stderr everytime to verify in logs whether is coming as empty or not. 
- this.telemetry.sendTelemetryEvent("dbtCloudParseProjectUserError", { - error: stderr, - adapter: this.getAdapterType() || "unknown", - }); - this.terminal.info( - "dbtCloudParseProject", - "dbt cloud cli response", - false, - { - command: command.getCommandAsString(), - stderr, - }, - ); - const errorsAndWarnings = stderr - .trim() - .split("\n") - .map((line) => line.trim()) - .filter((line) => Boolean(line)) - .map((line) => - this.parseJSON( - "RebuildManifestErrorsAndWarningsJSONParsing", - line, - false, - ), - ); - const errors = errorsAndWarnings - .filter( - (line) => - line && - line.hasOwnProperty("info") && - line.info.hasOwnProperty("level") && - line.info.hasOwnProperty("msg") && - ["error", "fatal"].includes(line.info.level), - ) - .map((line) => line.info.msg); - const warnings = errorsAndWarnings - .filter( - (line) => - line && - line.hasOwnProperty("info") && - line.info.hasOwnProperty("level") && - line.info.hasOwnProperty("msg") && - line.info.level === "warn", - ) - .map((line) => line.info.msg); - this.rebuildManifestDiagnostics.clear(); - const diagnostics: Array = errors - .map( - (error) => - new Diagnostic( - new Range(0, 0, 999, 999), - error, - DiagnosticSeverity.Error, - ), - ) - .concat( - warnings.map( - (warning) => - new Diagnostic( - new Range(0, 0, 999, 999), - warning, - DiagnosticSeverity.Warning, - ), - ), - ); - if (diagnostics) { - // user error - this.rebuildManifestDiagnostics.set( - Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE), - diagnostics, - ); - } - } catch (error) { - this.telemetry.sendTelemetryError( - "dbtCloudCannotParseProjectCommandExecuteError", - error, - { - adapter: this.getAdapterType() || "unknown", - command: command.getCommandAsString(), - }, - ); - this.rebuildManifestDiagnostics.set( - Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE), - [ - new Diagnostic( - new Range(0, 0, 999, 999), - "Unable to parse dbt cloud cli response. 
If the problem persists please reach out to us: " + - error, - DiagnosticSeverity.Error, - ), - ], - ); - } - } - - async runModel(command: DBTCommand) { - return this.addCommandToQueue( - DBTCloudProjectIntegration.QUEUE_ALL, - await this.addDeferParams(this.dbtCloudCommand(command)), - ); - } - - async buildModel(command: DBTCommand) { - return this.addCommandToQueue( - DBTCloudProjectIntegration.QUEUE_ALL, - await this.addDeferParams(this.dbtCloudCommand(command)), - ); - } - - async buildProject(command: DBTCommand) { - return this.addCommandToQueue( - DBTCloudProjectIntegration.QUEUE_ALL, - await this.addDeferParams(this.dbtCloudCommand(command)), - ); - } - - async runTest(command: DBTCommand) { - return this.addCommandToQueue( - DBTCloudProjectIntegration.QUEUE_ALL, - await this.addDeferParams(this.dbtCloudCommand(command)), - ); - } - - async runModelTest(command: DBTCommand) { - return this.addCommandToQueue( - DBTCloudProjectIntegration.QUEUE_ALL, - await this.addDeferParams(this.dbtCloudCommand(command)), - ); - } - - async compileModel(command: DBTCommand) { - this.addCommandToQueue( - DBTCloudProjectIntegration.QUEUE_ALL, - await this.addDeferParams(this.dbtCloudCommand(command)), - ); - } - - async generateDocs(command: DBTCommand) { - this.addCommandToQueue( - DBTCloudProjectIntegration.QUEUE_ALL, - this.dbtCloudCommand(command), - ); - } - - async clean(command: DBTCommand): Promise { - this.throwIfNotAuthenticated(); - const { stdout, stderr } = await this.dbtCloudCommand(command).execute(); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return stdout; - } - - async executeCommandImmediately(command: DBTCommand) { - return await this.dbtCloudCommand(command).execute(); - } - - async deps(command: DBTCommand): Promise { - const { stdout, stderr } = await this.dbtCloudCommand(command).execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout; - } - - async debug(command: DBTCommand): Promise { - command.args = ["environment", "show"]; - const { stdout, stderr } = await this.dbtCloudCommand(command).execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout; - } - - private async getDeferParams(): Promise { - this.throwIfNotAuthenticated(); - const deferConfig = this.deferToProdService.getDeferConfigByProjectRoot( - this.projectRoot.fsPath, - ); - const { deferToProduction } = deferConfig; - // explicitly checking false to make sure defer is disabled - if (!deferToProduction) { - this.terminal.debug("Defer to Prod", "defer to prod not enabled"); - return ["--no-defer"]; - } - return []; - } - - private async addDeferParams(command: DBTCommand) { - const deferParams = await this.getDeferParams(); - deferParams.forEach((param) => command.addArgument(param)); - return command; - } - - protected dbtCloudCommand(command: DBTCommand) { - command.setExecutionStrategy( - this.cliDBTCommandExecutionStrategyFactory( - this.projectRoot, - this.dbtPath, - ), - ); - command.addArgument("--source"); - command.addArgument("dbt-power-user"); - const currentVersion = this.getVersion() - .map((part) => new String(part)) - .join("."); - const downloadArtifactsVersion = "0.37.20"; - if (semver.gte(currentVersion, downloadArtifactsVersion)) { - if (command.downloadArtifacts) { - command.addArgument("--download-artifacts"); - } - } - return command; - } - - private addCommandToQueue(queueName: string, command: DBTCommand) { - try { - this.throwIfNotAuthenticated(); - return 
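> Editor's aside (not part of the diff): `dbtCloudCommand` above tags every invocation with `--source dbt-power-user` and only adds `--download-artifacts` when the detected dbt Cloud CLI version is new enough. A small sketch of that version gate, assuming the same `semver` dependency; the helper name is illustrative.

```typescript
import * as semver from "semver";

// --download-artifacts is only understood by dbt Cloud CLI >= 0.37.20,
// so the flag is added conditionally based on the detected version.
const DOWNLOAD_ARTIFACTS_MIN_VERSION = "0.37.20";

export function shouldDownloadArtifacts(version: number[]): boolean {
  const current = version.join(".");
  return semver.valid(current) !== null && semver.gte(current, DOWNLOAD_ARTIFACTS_MIN_VERSION);
}

// e.g. shouldDownloadArtifacts([0, 38, 6]) === true
//      shouldDownloadArtifacts([0, 36, 1]) === false
```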
this.executionInfrastructure.addCommandToQueue(queueName, command); - } catch (e) { - window.showErrorMessage((e as Error).message); - } - } - - // internal commands - async unsafeCompileNode(modelName: string): Promise { - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Compiling model...", [ - "compile", - "--model", - modelName, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter((line) => line.data?.hasOwnProperty("compiled")); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return compiledLine[0].data.compiled; - } - - async unsafeCompileQuery(query: string): Promise { - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Compiling sql...", [ - "compile", - "--inline", - query, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter((line) => line.data?.hasOwnProperty("compiled")); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return compiledLine[0].data.compiled; - } - - async validateSql( - query: string, - dialect: string, - models: any, - ): Promise { - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const result = await this.python?.lock( - (python) => - python!`to_dict(validate_sql(${query}, ${dialect}, ${models}))`, - ); - return result; - } - - async validateSQLDryRun(query: string): Promise<{ bytes_processed: string }> { - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const validateSqlCommand = this.dbtCloudCommand( - new DBTCommand("Estimating BigQuery cost...", [ - "compile", - "--inline", - `{{ validate_sql('${query}') }}`, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await validateSqlCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter((line) => line.data?.hasOwnProperty("compiled")); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return JSON.parse(compiledLine[0].data.compiled); - } - - async getColumnsOfSource( - sourceName: string, - tableName: string, - ): Promise { - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting columns of source...", [ - "compile", - "--inline", - `{% set output = [] %}{% for result in adapter.get_columns_in_relation(source('${sourceName}', '${tableName}')) %} {% do output.append({"column": result.name, "dtype": result.dtype}) %} {% endfor %} {{ tojson(output) }}`, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter((line) => line.data?.hasOwnProperty("compiled")); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return JSON.parse(compiledLine[0].data.compiled); - } - - async 
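> Editor's aside (not part of the diff): `getColumnsOfSource` above (and `getColumnsOfModel` below) work by compiling a small inline Jinja snippet that asks the adapter for a relation's columns and serialises them with `tojson`. A sketch of how such a snippet can be built; the builder function is illustrative, the Jinja body mirrors the deleted code.

```typescript
// Build an inline Jinja snippet that lists the columns of one relation.
// `relation` is a ref(...) or source(...) expression supplied by the caller.
export function buildColumnsJinja(relation: string): string {
  return (
    `{% set output = [] %}` +
    `{% for result in adapter.get_columns_in_relation(${relation}) %}` +
    ` {% do output.append({"column": result.name, "dtype": result.dtype}) %} ` +
    `{% endfor %} {{ tojson(output) }}`
  );
}

// e.g. buildColumnsJinja(`source('raw', 'orders')`) or buildColumnsJinja(`ref('stg_orders')`)
// is passed to `dbt compile --inline ... --output json --log-format json`,
// and the column list is JSON.parse'd out of the compiled line.
```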
getColumnsOfModel(modelName: string): Promise { - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting columns of model...", [ - "compile", - "--inline", - `{% set output = [] %}{% for result in adapter.get_columns_in_relation(ref('${modelName}')) %} {% do output.append({"column": result.name, "dtype": result.dtype}) %} {% endfor %} {{ tojson(output) }}`, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter((line) => line.data?.hasOwnProperty("compiled")); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return JSON.parse(compiledLine[0].data.compiled); - } - - async validateWhetherSqlHasColumns( - sql: string, - dialect: string, - ): Promise { - this.throwBridgeErrorIfAvailable(); - return this.python?.lock( - (python) => - python!`to_dict(validate_whether_sql_has_columns(${sql}, ${dialect}))`, - ); - } - - async fetchSqlglotSchema(sql: string, dialect: string): Promise { - this.throwBridgeErrorIfAvailable(); - return this.python?.lock( - (python) => python!`to_dict(fetch_schema_from_sql(${sql}, ${dialect}))`, - ); - } - - async getBulkCompiledSQL(models: NodeMetaData[]) { - const downloadArtifactsVersion = "0.37.20"; - const currentVersion = this.getVersion() - .map((part) => new String(part)) - .join("."); - if (semver.gte(currentVersion, downloadArtifactsVersion)) { - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting catalog...", [ - "compile", - "--download-artifacts", - "--model", - `"${models.map((item) => item.name).join(" ")}"`, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stderr } = await compileQueryCommand.execute( - new CancellationTokenSource().token, - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - } - const result: Record = {}; - for (const node of models) { - try { - // compiled sql file exists - const fileContentBytes = await workspace.fs.readFile( - Uri.file(node.compiled_path), - ); - const query = fileContentBytes.toString(); - result[node.uniqueId] = query; - continue; - } catch (e) { - this.terminal.error( - "getBulkCompiledSQL", - `Unable to find compiled sql file for model ${node.uniqueId}`, - e, - true, - ); - } - - try { - // compiled sql file doesn't exists or dbt below 0.37.20 - result[node.uniqueId] = await this.unsafeCompileNode(node.name); - } catch (e) { - this.terminal.error( - "getBulkCompiledSQL", - `Unable to compile sql for model ${node.uniqueId}`, - e, - true, - ); - } - } - return result; - } - - async getBulkSchemaFromDB( - nodes: DBTNode[], - cancellationToken: CancellationToken, - ): Promise> { - if (nodes.length === 0) { - return {}; - } - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const bulkModelQuery = ` -{% set result = {} %} -{% for n in ${JSON.stringify(nodes)} %} - {% set columns = adapter.get_columns_in_relation(ref(n["name"])) %} - {% set new_columns = [] %} - {% for column in columns %} - {% do new_columns.append({"column": column.name, "dtype": column.dtype}) %} - {% endfor %} - {% do result.update({n["unique_id"]:new_columns}) %} -{% endfor %} -{% for n in graph.sources.values() %} - {% set columns = adapter.get_columns_in_relation(source(n["source_name"], n["identifier"])) %} - {% 
set new_columns = [] %} - {% for column in columns %} - {% do new_columns.append({"column": column.name, "dtype": column.dtype}) %} - {% endfor %} - {% do result.update({n["unique_id"]:new_columns}) %} -{% endfor %} -{{ tojson(result) }}`; - console.log(bulkModelQuery); - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting catalog...", [ - "compile", - "--inline", - bulkModelQuery.trim().split("\n").join(""), - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = - await compileQueryCommand.execute(cancellationToken); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter((line) => line.data?.hasOwnProperty("compiled")); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return JSON.parse(compiledLine[0].data.compiled); - } - - async getCatalog(): Promise { - this.throwIfNotAuthenticated(); - this.throwBridgeErrorIfAvailable(); - const bulkModelQuery = ` -{% set result = [] %} -{% for n in graph.nodes.values() %} - {% if n.resource_type == "test" or - n.resource_type == "analysis" or - n.resource_type == "sql_operation" or - n.config.materialized == "ephemeral" %} - {% continue %} - {% endif %} - {% set columns = adapter.get_columns_in_relation(ref(n["name"])) %} - {% for column in columns %} - {% do result.append({ - "table_database": n.database, - "table_schema": n.schema, - "table_name": n.name, - "column_name": column.name, - "column_type": column.dtype, - }) %} - {% endfor %} -{% endfor %} -{% for n in graph.sources.values() %} - {% set columns = adapter.get_columns_in_relation(source(n["source_name"], n["identifier"])) %} - {% for column in columns %} - {% do result.append({ - "table_database": n.database, - "table_schema": n.schema, - "table_name": n.name, - "column_name": column.name, - "column_type": column.dtype, - }) %} - {% endfor %} -{% endfor %} -{{ tojson(result) }}`; - - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting catalog...", [ - "compile", - "--inline", - bulkModelQuery, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter( - (line) => - line.hasOwnProperty("data") && line.data?.hasOwnProperty("compiled"), - ); - if (compiledLine.length === 0) { - throw new Error("Could not get bulk schema from response: " + stdout); - } - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - const result: Catalog = JSON.parse(compiledLine[0].data.compiled); - return result; - } - - getDebounceForRebuildManifest() { - return 500; - } - - // get dbt config - protected async initializePaths() { - const packagePathsCommand = this.dbtCloudCommand( - new DBTCommand("Getting paths...", [ - "environment", - "show", - "--project-paths", - ]), - ); - try { - const { stdout, stderr } = await packagePathsCommand.execute(); - if (stderr) { - this.terminal.warn( - "DbtCloudIntegrationInitializePathsStdError", - "packaging paths command returns warning, ignoring", - true, - stderr, - ); - } - const lookupValue = (lookupString: string) => { - const regexString = `${lookupString}\\s*(.*)`; - const regexp = new RegExp(regexString, "gm"); - const matches = regexp.exec(stdout); - if (matches?.length === 2) { - return matches[1]; - } - throw new Error(`Could not find any entries for ${lookupString}`); - 
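> Editor's aside (not part of the diff): `getBulkCompiledSQL` above prefers compiled SQL that is already materialised under `target/` (after a `--download-artifacts` compile) and only falls back to compiling a model on demand. A sketch of that two-path strategy; `ModelNode`, `compiledPath`, and the injected `compileNode` callback are stand-ins for the extension's own types and `unsafeCompileNode`.

```typescript
import { Uri, workspace } from "vscode";

interface ModelNode {
  name: string;
  uniqueId: string;
  compiledPath: string; // absolute path of the model's compiled SQL under target/
}

export async function getBulkCompiledSql(
  models: ModelNode[],
  compileNode: (modelName: string) => Promise<string>,
): Promise<Record<string, string>> {
  const result: Record<string, string> = {};
  for (const node of models) {
    try {
      // Fast path: the compiled SQL file already exists on disk.
      const bytes = await workspace.fs.readFile(Uri.file(node.compiledPath));
      result[node.uniqueId] = Buffer.from(bytes).toString("utf-8");
    } catch {
      // Slow path: the file is missing, compile the model through the CLI.
      result[node.uniqueId] = await compileNode(node.name);
    }
  }
  return result;
}
```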
}; - const lookupEntries = (lookupString: string) => { - const regexString = `${lookupString}\\s*\\[(.*)\\]`; - const regexp = new RegExp(regexString, "gm"); - const matches = regexp.exec(stdout); - if (matches?.length === 2) { - return matches[1].split(",").map((m) => m.slice(1, -1)); - } - throw new Error(`Could not find any entries for ${lookupString}`); - }; - this.targetPath = join(this.projectRoot.fsPath, "target"); - this.modelPaths = lookupEntries("Model paths").map((p) => - join(this.projectRoot.fsPath, p), - ); - this.seedPaths = lookupEntries("Seed paths").map((p) => - join(this.projectRoot.fsPath, p), - ); - this.macroPaths = lookupEntries("Macro paths").map((p) => - join(this.projectRoot.fsPath, p), - ); - this.packagesInstallPath = join(this.projectRoot.fsPath, "dbt_packages"); - this.adapterType = lookupValue("Connection type"); - } catch (error) { - this.terminal.warn( - "DbtCloudIntegrationInitializePathsExceptionError", - "dbt environment show not returning required info, ignoring", - true, - error, - ); - this.targetPath = join(this.projectRoot.fsPath, "target"); - this.modelPaths = [join(this.projectRoot.fsPath, "models")]; - this.seedPaths = [join(this.projectRoot.fsPath, "seeds")]; - this.macroPaths = [join(this.projectRoot.fsPath, "macros")]; - this.packagesInstallPath = join(this.projectRoot.fsPath, "dbt_packages"); - } - - try { - const projectConfig = DBTProject.readAndParseProjectConfig( - this.projectRoot, - ); - this.projectName = projectConfig.name; - } catch (error) { - this.terminal.warn( - "DbtCloudIntegrationProjectNameFromConfigExceptionError", - "project name could not be read from dbt_project.yml, ignoring", - true, - error, - ); - } - } - - private async findVersion() { - try { - const versionCommand = this.dbtCloudCommand( - new DBTCommand("Getting version...", ["--version"]), - ); - const { stdout } = await versionCommand.execute(); - if (stdout.includes("dbt Cloud CLI")) { - const regex = /dbt Cloud CLI - (\d*\.\d*\.\d*)/gm; - const matches = regex.exec(stdout); - if (matches?.length === 2) { - this.version = matches[1].split(".").map((part) => parseInt(part)); - } else { - this.terminal.debug( - "DBTCLIDetectionFailed", - "dbt cloud cli was not found. Detection command returned : " + - stdout, - ); - } - } - } catch (error) { - this.terminal.warn( - "findVersion", - "Version lookup failed with error : " + (error as Error).message, - ); - } - } - - protected processJSONErrors(jsonErrors: string) { - if (!jsonErrors) { - return; - } - try { - const errorLines: string[] = []; - // eslint-disable-next-line prefer-spread - errorLines.push.apply( - errorLines, - jsonErrors - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter( - (line) => - line.info.level === "error" || line.info.level === "fatal", - ) - .map((line) => line.info.msg), - ); - if (errorLines.length) { - return new Error(errorLines.join(", ")); - } - } catch (error) { - // ideally we never come here, this is a bug in our code - return new Error("Could not process " + jsonErrors + ": " + error); - } - } - - private throwIfNotAuthenticated() { - this.validationProvider.throwIfNotAuthenticated(); - } - - async dispose() { - try { - await this.executionInfrastructure.closePythonBridge(this.python); - } catch (error) {} // We don't care about errors here. 
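> Editor's aside (not part of the diff): `initializePaths` above scrapes model/seed/macro paths and the connection type out of the plain-text output of `dbt environment show --project-paths` with two small regex helpers. A standalone sketch of those helpers; the sample output format in the comment is assumed, not verified.

```typescript
export function lookupValue(stdout: string, label: string): string {
  const matches = new RegExp(`${label}\\s*(.*)`, "m").exec(stdout);
  if (matches?.length === 2) {
    return matches[1];
  }
  throw new Error(`Could not find any entries for ${label}`);
}

export function lookupEntries(stdout: string, label: string): string[] {
  const matches = new RegExp(`${label}\\s*\\[(.*)\\]`, "m").exec(stdout);
  if (matches?.length === 2) {
    // Entries are rendered as quoted strings, e.g. ['models']; strip the quotes.
    return matches[1].split(",").map((entry) => entry.trim().slice(1, -1));
  }
  throw new Error(`Could not find any entries for ${label}`);
}

// Assumed example output lines:
//   Model paths ['models']
//   Connection type snowflake
// lookupEntries(stdout, "Model paths") => ["models"]
// lookupValue(stdout, "Connection type") => "snowflake"
```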
- this.rebuildManifestDiagnostics.clear(); - this.pythonBridgeDiagnostics.clear(); - while (this.disposables.length) { - const x = this.disposables.pop(); - if (x) { - x.dispose(); - } - } - } - - private getYamlContent(uri: Uri): string | undefined { - try { - return readFileSync(uri.fsPath, "utf-8"); - } catch (error) { - this.terminal.error( - "getYamlContent", - "Error occured while reading file: " + uri.fsPath, - error, - ); - return undefined; - } - } - - findPackageVersion(packageName: string) { - const packagesYmlPath = Uri.joinPath(this.projectRoot, "packages.yml"); - const dependenciesYmlPath = Uri.joinPath( - this.projectRoot, - "dependencies.yml", - ); - - const fileContents = - this.getYamlContent(packagesYmlPath) || - this.getYamlContent(dependenciesYmlPath); - if (!fileContents) { - return undefined; - } - - const packages = parse(fileContents) as - | { packages: { package: string; version: string }[] } - | undefined; - if (packages?.packages?.length) { - const packageObject = packages.packages.find( - (p) => p.package.indexOf(packageName) > -1, - ); - return packageObject?.version as string; - } - return undefined; - } - - private throwBridgeErrorIfAvailable() { - const allDiagnostics: DiagnosticCollection[] = [ - this.pythonBridgeDiagnostics, - this.rebuildManifestDiagnostics, - ]; - - for (const diagnosticCollection of allDiagnostics) { - for (const [_, diagnostics] of diagnosticCollection) { - const error = diagnostics.find( - (diagnostic) => diagnostic.severity === DiagnosticSeverity.Error, - ); - if (error) { - throw new Error(error.message); - } - } - } - } - - async performDatapilotHealthcheck({ - manifestPath, - catalogPath, - config, - configPath, - }: HealthcheckArgs): Promise { - this.throwBridgeErrorIfAvailable(); - const result = await this.python?.lock( - (python) => - python!`to_dict(project_healthcheck(${manifestPath}, ${catalogPath}, ${configPath}, ${config}, ${this.altimateRequest.getAIKey()}, ${this.altimateRequest.getInstanceName()}, ${AltimateRequest.ALTIMATE_URL}))`, - ); - return result; - } - - async applyDeferConfig(): Promise {} - - async applySelectedTarget(): Promise {} - - throwDiagnosticsErrorIfAvailable(): void { - this.throwBridgeErrorIfAvailable(); - } - - protected parseJSON( - contextName: string, - json: string, - throw_: boolean = true, - ): any { - try { - return JSON.parse(json); - } catch (error) { - this.terminal.error( - "dbtCloud" + contextName + "Error", - "An error occured while parsing following json: " + json, - error, - ); - if (throw_) { - throw error; - } - } - } -} diff --git a/src/dbt_client/dbtCoreCommandIntegration.ts b/src/dbt_client/dbtCoreCommandIntegration.ts deleted file mode 100644 index da48cda91..000000000 --- a/src/dbt_client/dbtCoreCommandIntegration.ts +++ /dev/null @@ -1,462 +0,0 @@ -import { CancellationToken, CancellationTokenSource } from "vscode"; -import { provideSingleton } from "../utils"; -import { - DBTCoreDetection, - DBTCoreProjectDetection, - DBTCoreProjectIntegration, -} from "./dbtCoreIntegration"; -import { - QueryExecution, - DBTCommand, - DBColumn, - Catalog, - DBTNode, -} from "./dbtIntegration"; -import { getDBTPath } from "./dbtCloudIntegration"; - -// TODO: either fix this class or remove it -@provideSingleton(DBTCoreCommandDetection) -export class DBTCoreCommandDetection extends DBTCoreDetection {} - -// TODO: either fix this class or remove it -@provideSingleton(DBTCoreCommandProjectDetection) -export class DBTCoreCommandProjectDetection extends DBTCoreProjectDetection {} - 
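> Editor's aside (not part of the diff): `findPackageVersion` in the dbt Cloud integration above resolves a package's pinned version by parsing `packages.yml`, falling back to `dependencies.yml`. A self-contained sketch of the same lookup using the `yaml` dependency already imported by the deleted code; the function signature here is illustrative.

```typescript
import { existsSync, readFileSync } from "fs";
import * as path from "path";
import { parse } from "yaml";

interface PackagesFile {
  packages?: { package?: string; version?: string | number }[];
}

export function findPackageVersion(projectRoot: string, packageName: string): string | undefined {
  for (const file of ["packages.yml", "dependencies.yml"]) {
    const fullPath = path.join(projectRoot, file);
    if (!existsSync(fullPath)) {
      continue;
    }
    const parsed = parse(readFileSync(fullPath, "utf-8")) as PackagesFile | undefined;
    // Match on a substring so "dbt-labs/dbt_utils" is found by "dbt_utils".
    const match = parsed?.packages?.find((p) => p.package?.includes(packageName));
    if (match?.version !== undefined) {
      return String(match.version);
    }
  }
  return undefined;
}

// e.g. findPackageVersion("/work/analytics", "dbt_utils") might return "1.1.1"
```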
-@provideSingleton(DBTCoreProjectIntegration) -export class DBTCoreCommandProjectIntegration extends DBTCoreProjectIntegration { - private dbtPath = "dbt"; - - refreshProjectConfig(): Promise { - this.dbtPath = getDBTPath(this.pythonEnvironment, this.dbtTerminal); - return super.refreshProjectConfig(); - } - - async executeSQL( - query: string, - limit: number, - modelName: string, - ): Promise { - this.throwBridgeErrorIfAvailable(); - const showCommand = this.dbtCoreCommand( - new DBTCommand("Running sql...", [ - "show", - "--log-level", - "debug", - "--inline", - query, - "--limit", - limit.toString(), - "--output", - "json", - "--log-format", - "json", - ]), - ); - const cancellationTokenSource = new CancellationTokenSource(); - showCommand.setToken(cancellationTokenSource.token); - return new QueryExecution( - async () => { - cancellationTokenSource.cancel(); - }, - async () => { - const { stdout, stderr } = await showCommand.execute( - cancellationTokenSource.token, - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - const parsedLines = stdout - .trim() - .split("\n") - .map((line) => { - try { - return JSON.parse(line.trim()); - } catch (err) {} - }); - const previewLine = parsedLines.filter( - (line) => - line && - line.hasOwnProperty("data") && - line.data.hasOwnProperty("preview"), - ); - const compiledSqlLines = parsedLines.filter( - (line) => - line && - line.hasOwnProperty("data") && - line.data.hasOwnProperty("sql"), - ); - const preview = JSON.parse(previewLine[0].data.preview); - const compiledSql = - compiledSqlLines[compiledSqlLines.length - 1].data.sql; - return { - table: { - column_names: preview.length > 0 ? Object.keys(preview[0]) : [], - column_types: - preview.length > 0 - ? Object.keys(preview[0]).map((obj: any) => "string") - : [], - rows: preview.map((obj: any) => Object.values(obj)), - }, - compiled_sql: compiledSql, - raw_sql: query, - modelName, - }; - }, - ); - } - - protected dbtCoreCommand(command: DBTCommand) { - const newCommand = super.dbtCoreCommand(command); - newCommand.setExecutionStrategy( - this.cliDBTCommandExecutionStrategyFactory( - this.projectRoot, - this.dbtPath, - ), - ); - return newCommand; - } - - // internal commands - async unsafeCompileNode(modelName: string): Promise { - this.throwBridgeErrorIfAvailable(); - const compileQueryCommand = this.dbtCoreCommand( - new DBTCommand("Compiling model...", [ - "compile", - "--model", - modelName, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => { - try { - return JSON.parse(line.trim()); - } catch (err) {} - }) - .filter( - (line) => - line && - line.hasOwnProperty("data") && - line.data?.hasOwnProperty("compiled"), - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return compiledLine[0].data.compiled; - } - - async unsafeCompileQuery(query: string): Promise { - this.throwBridgeErrorIfAvailable(); - const compileQueryCommand = this.dbtCoreCommand( - new DBTCommand("Compiling sql...", [ - "compile", - "--inline", - query, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => { - try { - return JSON.parse(line.trim()); - } catch (err) {} - }) - .filter( - (line) => - line && - line.hasOwnProperty("data") && 
- line.data?.hasOwnProperty("compiled"), - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return compiledLine[0].data.compiled; - } - - async validateSQLDryRun(query: string): Promise<{ bytes_processed: string }> { - this.throwBridgeErrorIfAvailable(); - const validateSqlCommand = this.dbtCoreCommand( - new DBTCommand("Estimating BigQuery cost...", [ - "compile", - "--inline", - `{{ validate_sql('${query}') }}`, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await validateSqlCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => { - try { - return JSON.parse(line.trim()); - } catch (err) {} - }) - .filter( - (line) => - line && - line.hasOwnProperty("data") && - line.data?.hasOwnProperty("compiled"), - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return JSON.parse(compiledLine[0].data.compiled); - } - - async getColumnsOfSource( - sourceName: string, - tableName: string, - ): Promise { - this.throwBridgeErrorIfAvailable(); - const compileQueryCommand = this.dbtCoreCommand( - new DBTCommand("Getting columns of source...", [ - "compile", - "--inline", - `{% set output = [] %}{% for result in adapter.get_columns_in_relation(source('${sourceName}', '${tableName}')) %} {% do output.append({"column": result.name, "dtype": result.dtype}) %} {% endfor %} {{ tojson(output) }}`, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => { - try { - return JSON.parse(line.trim()); - } catch (err) {} - }) - .filter( - (line) => - line && - line.hasOwnProperty("data") && - line.data?.hasOwnProperty("compiled"), - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return JSON.parse(compiledLine[0].data.compiled); - } - - async getColumnsOfModel(modelName: string): Promise { - this.throwBridgeErrorIfAvailable(); - const compileQueryCommand = this.dbtCoreCommand( - new DBTCommand("Getting columns of model...", [ - "compile", - "--inline", - `{% set output = [] %}{% for result in adapter.get_columns_in_relation(ref('${modelName}')) %} {% do output.append({"column": result.name, "dtype": result.dtype}) %} {% endfor %} {{ tojson(output) }}`, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => { - try { - return JSON.parse(line.trim()); - } catch (err) {} - }) - .filter( - (line) => - line && - line.hasOwnProperty("data") && - line.data?.hasOwnProperty("compiled"), - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return JSON.parse(compiledLine[0].data.compiled); - } - - private processJSONErrors(jsonErrors: string) { - if (!jsonErrors) { - return; - } - try { - const errorLines: string[] = []; - // eslint-disable-next-line prefer-spread - errorLines.push.apply( - errorLines, - jsonErrors - .trim() - .split("\n") - .map((line) => { - try { - return JSON.parse(line.trim()); - } catch (err) {} - }) - .filter( - (line) => - line && - line.hasOwnProperty("info") && - line.info.hasOwnProperty("level") && - (line.info.level === "error" || line.info.level === "fatal"), - ) - .map((line) => line.info.msg), - ); - if (errorLines.length) { - return new 
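> Editor's aside (not part of the diff): unlike the Cloud integration, the command integration above wraps every `JSON.parse` in a try/catch so that plain-text noise in the CLI output does not fail the whole command. A generic sketch of that tolerant NDJSON parsing; the function name is illustrative.

```typescript
export function parseJsonLines<T = unknown>(stdout: string): T[] {
  return stdout
    .trim()
    .split("\n")
    .map((line) => {
      try {
        return JSON.parse(line.trim()) as T;
      } catch {
        return undefined; // not a JSON log line; skip it
      }
    })
    .filter((line): line is T => line !== undefined);
}
```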
Error(errorLines.join(", ")); - } - } catch (error) { - // ideally we never come here, this is a bug in our code - return new Error("Could not process " + jsonErrors + ": " + error); - } - } - - async getCatalog(): Promise { - this.throwBridgeErrorIfAvailable(); - const bulkModelQuery = ` -{% set result = [] %} -{% for n in graph.nodes.values() %} - {% if n.resource_type == "test" or - n.resource_type == "analysis" or - n.resource_type == "sql_operation" or - n.config.materialized == "ephemeral" %} - {% continue %} - {% endif %} - {% set columns = adapter.get_columns_in_relation(ref(n["name"])) %} - {% for column in columns %} - {% do result.append({ - "table_database": n.database, - "table_schema": n.schema, - "table_name": n.name, - "column_name": column.name, - "column_type": column.dtype, - }) %} - {% endfor %} -{% endfor %} -{% for n in graph.sources.values() %} - {% set columns = adapter.get_columns_in_relation(source(n["source_name"], n["identifier"])) %} - {% for column in columns %} - {% do result.append({ - "table_database": n.database, - "table_schema": n.schema, - "table_name": n.name, - "column_name": column.name, - "column_type": column.dtype, - }) %} - {% endfor %} -{% endfor %} -{{ tojson(result) }}`; - - const compileQueryCommand = this.dbtCoreCommand( - new DBTCommand("Getting catalog...", [ - "compile", - "--inline", - bulkModelQuery.trim().split("\n").join(""), - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => { - try { - return JSON.parse(line.trim()); - } catch (err) {} - }) - .filter( - (line) => - line && - line.hasOwnProperty("data") && - line.data?.hasOwnProperty("compiled"), - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - const result: Catalog = JSON.parse(compiledLine[0].data.compiled); - return result; - } - - async getBulkSchemaFromDB( - nodes: DBTNode[], - cancellationToken: CancellationToken, - ): Promise> { - if (nodes.length === 0) { - return {}; - } - this.throwBridgeErrorIfAvailable(); - const bulkModelQuery = ` - {% set result = {} %} - {% for n in ${JSON.stringify(nodes)} %} - {% set columns = adapter.get_columns_in_relation(ref(n["name"])) %} - {% set new_columns = [] %} - {% for column in columns %} - {% do new_columns.append({"column": column.name, "dtype": column.dtype}) %} - {% endfor %} - {% do result.update({n["unique_id"]:new_columns}) %} - {% endfor %} - {% for n in graph.sources.values() %} - {% set columns = adapter.get_columns_in_relation(source(n["source_name"], n["identifier"])) %} - {% set new_columns = [] %} - {% for column in columns %} - {% do new_columns.append({"column": column.name, "dtype": column.dtype}) %} - {% endfor %} - {% do result.update({n["unique_id"]:new_columns}) %} - {% endfor %} - {{ tojson(result) }}`; - console.log(bulkModelQuery); - const compileQueryCommand = this.dbtCoreCommand( - new DBTCommand("Getting catalog...", [ - "compile", - "--inline", - bulkModelQuery, - "--output", - "json", - "--log-format", - "json", - ]), - ); - const { stdout, stderr } = - await compileQueryCommand.execute(cancellationToken); - const compiledLine = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())) - .filter( - (line) => - line.hasOwnProperty("data") && line.data?.hasOwnProperty("compiled"), - ); - if (compiledLine.length === 0) { - throw new Error("Could not get bulk schema from response: " + stdout); - } - 
const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - return JSON.parse(compiledLine[0].data.compiled); - } -} diff --git a/src/dbt_client/dbtCoreIntegration.ts b/src/dbt_client/dbtCoreIntegration.ts deleted file mode 100644 index 8658e9f53..000000000 --- a/src/dbt_client/dbtCoreIntegration.ts +++ /dev/null @@ -1,1252 +0,0 @@ -import { - CancellationToken, - Diagnostic, - DiagnosticCollection, - DiagnosticSeverity, - Disposable, - languages, - Range, - RelativePattern, - Uri, - window, - workspace, -} from "vscode"; -import { - extendErrorWithSupportLinks, - getFirstWorkspacePath, - getProjectRelativePath, - provideSingleton, - setupWatcherHandler, -} from "../utils"; -import { - Catalog, - CompilationResult, - DBColumn, - DBTNode, - DBTCommand, - DBTCommandExecutionInfrastructure, - DBTDetection, - DBTProjectDetection, - DBTProjectIntegration, - ExecuteSQLResult, - PythonDBTCommandExecutionStrategy, - QueryExecution, - SourceNode, - Node, - ExecuteSQLError, - HealthcheckArgs, - CLIDBTCommandExecutionStrategy, - DBTCommandExecutionStrategy, -} from "./dbtIntegration"; -import { PythonEnvironment } from "../manifest/pythonEnvironment"; -import { CommandProcessExecutionFactory } from "../commandProcessExecution"; -import { PythonBridge, PythonException } from "python-bridge"; -import * as path from "path"; -import { DBTProject } from "../manifest/dbtProject"; -import { existsSync, readFileSync } from "fs"; -import { parse } from "yaml"; -import { TelemetryService } from "../telemetry"; -import { - AltimateRequest, - NotFoundError, - ValidateSqlParseErrorResponse, -} from "../altimate"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestPathType } from "../constants"; -import { DBTTerminal } from "./dbtTerminal"; -import { ValidationProvider } from "../validation_provider"; -import { DeferToProdService } from "../services/deferToProdService"; -import { NodeMetaData } from "../domain"; -import * as crypto from "crypto"; - -const DEFAULT_QUERY_TEMPLATE = "select * from ({query}) as query limit {limit}"; - -// TODO: we shouold really get these from manifest directly -interface ResolveReferenceNodeResult { - database: string; - schema: string; - alias: string; -} - -interface ResolveReferenceSourceResult { - database: string; - schema: string; - alias: string; - resource_type: string; - identifier: string; -} - -interface DeferConfig { - deferToProduction: boolean; - favorState: boolean; - manifestPathForDeferral: string; - manifestPathType?: ManifestPathType; - dbtCoreIntegrationId?: number; -} - -type InsightType = "Modelling" | "Test" | "structure"; - -interface Insight { - name: string; - type: InsightType; - message: string; - recommendation: string; - reason_to_flag: string; - metadata: { - model?: string; - model_unique_id?: string; - model_type?: string; - convention?: string | null; - }; -} - -type Severity = "ERROR" | "WARNING"; - -interface ModelInsight { - insight: Insight; - severity: Severity; - unique_id: string; - package_name: string; - path: string; - original_file_path: string; -} - -export interface ProjectHealthcheck { - model_insights: Record; - // package_insights: any; -} - -@provideSingleton(DBTCoreDetection) -export class DBTCoreDetection implements DBTDetection { - constructor( - private pythonEnvironment: PythonEnvironment, - private commandProcessExecutionFactory: CommandProcessExecutionFactory, - ) {} - - async detectDBT(): Promise { - try { - const checkDBTInstalledProcess = - 
this.commandProcessExecutionFactory.createCommandProcessExecution({ - command: this.pythonEnvironment.pythonPath, - args: ["-c", "import dbt"], - cwd: getFirstWorkspacePath(), - envVars: this.pythonEnvironment.environmentVariables, - }); - const { stderr } = await checkDBTInstalledProcess.complete(); - if (stderr) { - throw new Error(stderr); - } - return true; - } catch (error) { - return false; - } - } -} - -@provideSingleton(DBTCoreProjectDetection) -export class DBTCoreProjectDetection - implements DBTProjectDetection, Disposable -{ - constructor( - private executionInfrastructure: DBTCommandExecutionInfrastructure, - private dbtTerminal: DBTTerminal, - ) {} - - private getPackageInstallPathFallback( - projectDirectory: Uri, - packageInstallPath: string, - ): string { - const dbtProjectFile = path.join( - projectDirectory.fsPath, - "dbt_project.yml", - ); - if (existsSync(dbtProjectFile)) { - const dbtProjectConfig: any = parse(readFileSync(dbtProjectFile, "utf8")); - const packagesInstallPath = dbtProjectConfig["packages-install-path"]; - if (packagesInstallPath) { - if (path.isAbsolute(packagesInstallPath)) { - return packagesInstallPath; - } else { - return path.join(projectDirectory.fsPath, packagesInstallPath); - } - } - } - return packageInstallPath; - } - - async discoverProjects(projectDirectories: Uri[]): Promise { - let packagesInstallPaths = projectDirectories.map((projectDirectory) => - path.join(projectDirectory.fsPath, "dbt_packages"), - ); - let python: PythonBridge | undefined; - try { - python = this.executionInfrastructure.createPythonBridge( - getFirstWorkspacePath(), - ); - - await python.ex`from dbt_core_integration import *`; - const packagesInstallPathsFromPython = await python.lock( - (python) => - python`to_dict(find_package_paths(${projectDirectories.map( - (projectDirectory) => projectDirectory.fsPath, - )}))`, - ); - packagesInstallPaths = packagesInstallPaths.map( - (packageInstallPath, index) => { - const packageInstallPathFromPython = - packagesInstallPathsFromPython[index]; - if (packageInstallPathFromPython) { - return Uri.file(packageInstallPathFromPython).fsPath; - } - return packageInstallPath; - }, - ); - } catch (error) { - this.dbtTerminal.debug( - "dbtCoreIntegration:discoverProjects", - "An error occured while finding package paths: " + error, - ); - // Fallback to reading yaml files - packagesInstallPaths = projectDirectories.map((projectDirectory, idx) => - this.getPackageInstallPathFallback( - projectDirectory, - packagesInstallPaths[idx], - ), - ); - } finally { - if (python) { - this.executionInfrastructure.closePythonBridge(python); - } - } - - const filteredProjectFiles = projectDirectories.filter((uri) => { - return !packagesInstallPaths.some((packageInstallPath) => { - return uri.fsPath.startsWith(packageInstallPath!); - }); - }); - if (filteredProjectFiles.length > 20) { - window.showWarningMessage( - `dbt Power User detected ${filteredProjectFiles.length} projects in your work space, this will negatively affect performance.`, - ); - } - return filteredProjectFiles; - } - - async dispose() {} -} - -@provideSingleton(DBTCoreProjectIntegration) -export class DBTCoreProjectIntegration - implements DBTProjectIntegration, Disposable -{ - static DBT_PROFILES_FILE = "profiles.yml"; - - private profilesDir?: string; - private targetPath?: string; - private adapterType?: string; - private targetName?: string; - private version?: number[]; - private projectName: string = "unknown_" + crypto.randomUUID(); - private packagesInstallPath?: string; - 
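> Editor's aside (not part of the diff): when the Python bridge cannot report package paths, `discoverProjects` above falls back to reading `packages-install-path` from `dbt_project.yml` and resolving it against the project directory. A sketch of that fallback; the function name is illustrative.

```typescript
import { existsSync, readFileSync } from "fs";
import * as path from "path";
import { parse } from "yaml";

export function getPackageInstallPath(projectDir: string): string {
  const fallback = path.join(projectDir, "dbt_packages");
  const dbtProjectFile = path.join(projectDir, "dbt_project.yml");
  if (!existsSync(dbtProjectFile)) {
    return fallback;
  }
  const config = parse(readFileSync(dbtProjectFile, "utf8")) as Record<string, unknown> | undefined;
  const installPath = config?.["packages-install-path"];
  if (typeof installPath !== "string") {
    return fallback;
  }
  // A relative packages-install-path is interpreted relative to the project.
  return path.isAbsolute(installPath) ? installPath : path.join(projectDir, installPath);
}
```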
private modelPaths?: string[]; - private seedPaths?: string[]; - private macroPaths?: string[]; - protected python: PythonBridge; - private disposables: Disposable[] = []; - private readonly rebuildManifestDiagnostics = - languages.createDiagnosticCollection("dbt"); - private readonly pythonBridgeDiagnostics = - languages.createDiagnosticCollection("dbt"); - private static QUEUE_ALL = "all"; - - constructor( - private executionInfrastructure: DBTCommandExecutionInfrastructure, - protected pythonEnvironment: PythonEnvironment, - private telemetry: TelemetryService, - private pythonDBTCommandExecutionStrategy: PythonDBTCommandExecutionStrategy, - protected cliDBTCommandExecutionStrategyFactory: ( - path: Uri, - dbtPath: string, - ) => DBTCommandExecutionStrategy, - private dbtProjectContainer: DBTProjectContainer, - private altimateRequest: AltimateRequest, - protected dbtTerminal: DBTTerminal, - private validationProvider: ValidationProvider, - private deferToProdService: DeferToProdService, - protected projectRoot: Uri, - private projectConfigDiagnostics: DiagnosticCollection, - ) { - this.dbtTerminal.debug( - "DBTCoreProjectIntegration", - `Registering dbt core project at ${this.projectRoot}`, - ); - this.python = this.executionInfrastructure.createPythonBridge( - this.projectRoot.fsPath, - ); - this.executionInfrastructure.createQueue( - DBTCoreProjectIntegration.QUEUE_ALL, - ); - - this.disposables.push( - this.pythonEnvironment.onPythonEnvironmentChanged(() => { - this.python = this.executionInfrastructure.createPythonBridge( - this.projectRoot.fsPath, - ); - }), - this.rebuildManifestDiagnostics, - this.pythonBridgeDiagnostics, - ); - - this.isDbtLoomInstalled().then((isInstalled) => { - this.telemetry.setTelemetryCustomAttribute( - "dbtLoomInstalled", - `${isInstalled}`, - ); - }); - } - - private async isDbtLoomInstalled(): Promise { - try { - await this.python.ex`from dbt_loom import *`; - return true; - } catch (error) { - return false; - } - } - - // remove the trailing slashes if they exists, - // causes the quote to be escaped when passing to python - private removeTrailingSlashes(input: string | undefined) { - return input?.replace(/\\+$/, ""); - } - - private getLimitQuery(queryTemplate: string, query: string, limit: number) { - return queryTemplate - .replace("{query}", () => query) - .replace("{limit}", () => limit.toString()); - } - - private async getQuery( - query: string, - limit: number, - ): Promise<{ queryTemplate: string; limitQuery: string }> { - try { - const dbtVersion = await this.version; - //dbt supports limit macro after v1.5 - if (dbtVersion && dbtVersion[0] >= 1 && dbtVersion[1] >= 5) { - const args = { compiled_code: query, limit }; - const queryTemplateFromMacro = await this.python?.lock( - (python) => - python!`to_dict(project.execute_macro('get_show_sql', ${args}, ${query}))`, - ); - - this.dbtTerminal.debug( - "DBTCoreProjectIntegration", - "Using query template from macro", - queryTemplateFromMacro, - ); - return { - queryTemplate: queryTemplateFromMacro, - limitQuery: queryTemplateFromMacro, - }; - } - } catch (err) { - console.error("Error while getting get_show_sql macro", err); - this.telemetry.sendTelemetryError( - "executeMacroGetLimitSubquerySQLError", - err, - { adapter: this.adapterType || "unknown" }, - ); - } - - const queryTemplate = workspace - .getConfiguration("dbt") - .get("queryTemplate"); - - if (queryTemplate && queryTemplate !== DEFAULT_QUERY_TEMPLATE) { - console.log("Using user provided query template", queryTemplate); - const 
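> Editor's aside (not part of the diff): `getLimitQuery` above applies a query template with `{query}` and `{limit}` placeholders, using the default template unless the user overrides it via the `dbt.queryTemplate` setting. A sketch of that substitution; the function name is illustrative.

```typescript
const DEFAULT_QUERY_TEMPLATE = "select * from ({query}) as query limit {limit}";

export function applyLimitTemplate(
  query: string,
  limit: number,
  template: string = DEFAAULT_QUERY_TEMPLATE_PLACEHOLDER, // see note below
): string {
  // Replacement callbacks avoid `$&`-style patterns in the SQL being interpreted.
  return template
    .replace("{query}", () => query)
    .replace("{limit}", () => limit.toString());
}

// Note: the default parameter above should simply be DEFAULT_QUERY_TEMPLATE;
// spelled out here to flag that the constant mirrors the deleted code.
// e.g. applyLimitTemplate("select * from customers", 500, DEFAULT_QUERY_TEMPLATE)
//   => "select * from (select * from customers) as query limit 500"
```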
limitQuery = this.getLimitQuery(queryTemplate, query, limit); - - return { queryTemplate, limitQuery }; - } - - return { - queryTemplate: DEFAULT_QUERY_TEMPLATE, - limitQuery: this.getLimitQuery(DEFAULT_QUERY_TEMPLATE, query, limit), - }; - } - - async refreshProjectConfig(): Promise { - await this.createPythonDbtProject(this.python); - await this.python.ex`project.init_project()`; - this.targetName = await this.findSelectedTarget(); - this.targetPath = await this.findTargetPath(); - this.modelPaths = await this.findModelPaths(); - this.seedPaths = await this.findSeedPaths(); - this.macroPaths = await this.findMacroPaths(); - this.packagesInstallPath = await this.findPackagesInstallPath(); - this.version = await this.findVersion(); - this.projectName = await this.findProjectName(); - this.adapterType = await this.findAdapterType(); - } - - async findSelectedTarget(): Promise { - return await this.python.lock( - (python) => python`to_dict(project.config.target_name)`, - ); - } - - async setSelectedTarget(targetName: string): Promise { - await this.python.lock( - (python) => python`project.set_selected_target(${targetName})`, - ); - await this.refreshProjectConfig(); - } - - async getTargetNames(): Promise> { - return await this.python.lock( - (python) => python`to_dict(project.get_target_names())`, - ); - } - - async executeSQL( - query: string, - limit: number, - modelName: string, - ): Promise { - this.throwBridgeErrorIfAvailable(); - const { limitQuery } = await this.getQuery(query, limit); - - const queryThread = this.executionInfrastructure.createPythonBridge( - this.projectRoot.fsPath, - ); - return new QueryExecution( - async () => { - queryThread.kill(2); - }, - async () => { - await this.createPythonDbtProject(queryThread); - await queryThread.ex`project.init_project()`; - let result: ExecuteSQLResult; - // compile query - const compiledQuery = await this.unsafeCompileQuery( - limitQuery, - modelName, - ); - try { - // execute query - result = await queryThread!.lock( - (python) => python`to_dict(project.execute_sql(${compiledQuery}))`, - ); - const { manifestPathType } = - this.deferToProdService.getDeferConfigByProjectRoot( - this.projectRoot.fsPath, - ); - if (manifestPathType === ManifestPathType.REMOTE) { - this.altimateRequest.sendDeferToProdEvent(ManifestPathType.REMOTE); - } - } catch (err) { - const message = `Error while executing sql: ${compiledQuery}`; - this.dbtTerminal.error("dbtCore:executeSQL", message, err); - if (err instanceof PythonException) { - throw new ExecuteSQLError(err.exception.message, compiledQuery!); - } - throw new ExecuteSQLError((err as Error).message, compiledQuery!); - } finally { - await this.cleanupConnections(); - await queryThread.end(); - } - return { ...result, compiled_stmt: compiledQuery, modelName }; - }, - ); - } - - private async createPythonDbtProject(bridge: PythonBridge) { - await bridge.ex`from dbt_core_integration import *`; - const targetPath = this.removeTrailingSlashes( - await bridge.lock( - (python) => python`target_path(${this.projectRoot.fsPath})`, - ), - ); - const { deferToProduction, manifestPath, favorState } = - await this.getDeferConfig(); - await bridge.ex`project = DbtProject(project_dir=${this.projectRoot.fsPath}, profiles_dir=${this.profilesDir}, target_path=${targetPath}, defer_to_prod=${deferToProduction}, manifest_path=${manifestPath}, favor_state=${favorState}) if 'project' not in locals() else project`; - } - - async initializeProject(): Promise { - try { - await this.python - .ex`from dbt_core_integration 
import default_profiles_dir`; - await this.python.ex`from dbt_healthcheck import *`; - this.profilesDir = this.removeTrailingSlashes( - await this.python.lock( - (python) => python`default_profiles_dir(${this.projectRoot.fsPath})`, - ), - ); - if (this.profilesDir) { - const dbtProfileWatcher = workspace.createFileSystemWatcher( - new RelativePattern( - this.profilesDir, - DBTCoreProjectIntegration.DBT_PROFILES_FILE, - ), - ); - this.disposables.push( - dbtProfileWatcher, - // when the project config changes we need to re-init the dbt project - ...setupWatcherHandler(dbtProfileWatcher, () => - this.rebuildManifest(), - ), - ); - } - await this.createPythonDbtProject(this.python); - this.pythonBridgeDiagnostics.clear(); - } catch (exc: any) { - if (exc instanceof PythonException) { - // python errors can be about anything, so we just associate the error with the project file - // with a fixed range - if (exc.message.includes("No module named 'dbt'")) { - // Let's not create an error for each project if dbt is not detected - // This is already displayed in the status bar - return; - } - let errorMessage = - "An error occured while initializing the dbt project: " + - exc.exception.message; - if (exc.exception.type.module === "dbt.exceptions") { - // TODO: we can do provide solutions per type of dbt exception - errorMessage = - "An error occured while initializing the dbt project, dbt found following issue: " + - exc.exception.message; - } - this.pythonBridgeDiagnostics.set( - Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE), - [new Diagnostic(new Range(0, 0, 999, 999), errorMessage)], - ); - this.telemetry.sendTelemetryError("pythonBridgeInitPythonError", exc); - } else { - window.showErrorMessage( - extendErrorWithSupportLinks( - "An unexpected error occured while initializing the dbt project at " + - this.projectRoot + - ": " + - exc + - ".", - ), - ); - this.telemetry.sendTelemetryError("pythonBridgeInitError", exc); - } - } - } - - getSelectedTarget() { - return this.targetName; - } - - getTargetPath(): string | undefined { - return this.targetPath; - } - - getModelPaths(): string[] | undefined { - return this.modelPaths; - } - - getSeedPaths(): string[] | undefined { - return this.seedPaths; - } - - getMacroPaths(): string[] | undefined { - return this.macroPaths; - } - - getPackageInstallPath(): string | undefined { - return this.packagesInstallPath; - } - - getAdapterType(): string | undefined { - return this.adapterType; - } - - getVersion(): number[] | undefined { - return this.version; - } - - getProjectName(): string { - return this.projectName; - } - - async findAdapterType(): Promise { - return this.python.lock( - (python) => python`project.config.credentials.type`, - ); - } - - getPythonBridgeStatus(): boolean { - return this.python.connected; - } - - async cleanupConnections(): Promise { - try { - await this.python.ex`project.cleanup_connections()`; - } catch (exc) { - if (exc instanceof PythonException) { - this.telemetry.sendTelemetryEvent( - "pythonBridgeCleanupConnectionsError", - { - error: exc.exception.message, - adapter: this.getAdapterType() || "unknown", // TODO: this should be moved to dbtProject - }, - ); - } - this.telemetry.sendTelemetryEvent( - "pythonBridgeCleanupConnectionsUnexpectedError", - { - error: (exc as Error).message, - adapter: this.getAdapterType() || "unknown", // TODO: this should be moved to dbtProject - }, - ); - } - } - - getAllDiagnostic(): Diagnostic[] { - const projectURI = Uri.joinPath( - this.projectRoot, - 
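> Editor's aside (not part of the diff): `initializeProject` above watches the resolved `profiles.yml` and re-parses the project whenever it changes. A sketch of that wiring; `onChange` stands in for the extension's `setupWatcherHandler` helper and the rebuild call it triggers.

```typescript
import { Disposable, RelativePattern, workspace } from "vscode";

export function watchProfiles(profilesDir: string, onChange: () => void): Disposable[] {
  const watcher = workspace.createFileSystemWatcher(
    new RelativePattern(profilesDir, "profiles.yml"),
  );
  // Re-initialise / re-parse the project whenever profiles.yml is touched.
  return [
    watcher,
    watcher.onDidChange(onChange),
    watcher.onDidCreate(onChange),
    watcher.onDidDelete(onChange),
  ];
}
```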
DBTProject.DBT_PROJECT_FILE, - ); - return [ - ...(this.pythonBridgeDiagnostics.get(projectURI) || []), - ...(this.projectConfigDiagnostics.get(projectURI) || []), - ...(this.rebuildManifestDiagnostics.get(projectURI) || []), - ]; - } - - async rebuildManifest(): Promise { - const errors = this.projectConfigDiagnostics.get( - Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE), - ); - if (errors !== undefined && errors.length > 0) { - // No point in trying to rebuild the manifest if the config is not valid - return; - } - try { - await this.python.lock( - (python) => python`to_dict(project.safe_parse_project())`, - ); - this.rebuildManifestDiagnostics.clear(); - } catch (exc) { - if (exc instanceof PythonException) { - // dbt errors can be about anything, so we just associate the error with the project file - // with a fixed range - this.rebuildManifestDiagnostics.set( - Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE), - [ - new Diagnostic( - new Range(0, 0, 999, 999), - "There is a problem in your dbt project. Compilation failed: " + - exc.exception.message, - ), - ], - ); - this.telemetry.sendTelemetryEvent( - "pythonBridgeCannotParseProjectUserError", - { - error: exc.exception.message, - adapter: this.getAdapterType() || "unknown", // TODO: this should be moved to dbtProject - }, - ); - return; - } - // if we get here, it is not a dbt error but an extension error. - this.telemetry.sendTelemetryError( - "pythonBridgeCannotParseProjectUnknownError", - exc, - { - adapter: this.adapterType || "unknown", // TODO: this should be moved to dbtProject - }, - ); - window.showErrorMessage( - extendErrorWithSupportLinks( - "An error occured while rebuilding the dbt manifest: " + exc + ".", - ), - ); - } - } - - async runModel(command: DBTCommand) { - return this.addCommandToQueue( - await this.addDeferParams(this.dbtCoreCommand(command)), - ); - } - - async buildModel(command: DBTCommand) { - return this.addCommandToQueue( - await this.addDeferParams(this.dbtCoreCommand(command)), - ); - } - - async buildProject(command: DBTCommand) { - return this.addCommandToQueue( - await this.addDeferParams(this.dbtCoreCommand(command)), - ); - } - - async runTest(command: DBTCommand) { - return this.addCommandToQueue( - await this.addDeferParams(this.dbtCoreCommand(command)), - ); - } - - async runModelTest(command: DBTCommand) { - return this.addCommandToQueue( - await this.addDeferParams(this.dbtCoreCommand(command)), - ); - } - - async compileModel(command: DBTCommand) { - this.addCommandToQueue( - await this.addDeferParams(this.dbtCoreCommand(command)), - ); - } - - async generateDocs(command: DBTCommand) { - this.addCommandToQueue(this.dbtCoreCommand(command)); - } - - async clean(command: DBTCommand) { - const { stdout, stderr } = await this.dbtCoreCommand(command).execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout; - } - - async executeCommandImmediately(command: DBTCommand) { - return await this.dbtCoreCommand(command).execute(); - } - - async deps(command: DBTCommand) { - const { stdout, stderr } = await this.dbtCoreCommand(command).execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout; - } - - async debug(command: DBTCommand) { - const { stdout, stderr } = await this.dbtCoreCommand(command).execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout; - } - - private addCommandToQueue(command: DBTCommand) { - const isInstalled = - this.dbtProjectContainer.showErrorIfDbtOrPythonNotInstalled(); - if (!isInstalled) { - return; 
- } - return this.executionInfrastructure.addCommandToQueue( - DBTCoreProjectIntegration.QUEUE_ALL, - command, - ); - } - - private async getDeferManifestPath( - manifestPathType: ManifestPathType | undefined, - manifestPathForDeferral: string, - dbtCoreIntegrationId: number | undefined, - ): Promise { - if (!manifestPathType) { - const configNotPresent = new Error( - "Please configure defer to production functionality by specifying manifest path in Actions panel before using it.", - ); - throw configNotPresent; - } - if (manifestPathType === ManifestPathType.LOCAL) { - if (!manifestPathForDeferral) { - const configNotPresent = new Error( - "manifestPathForDeferral config is not present, use the actions panel to set the Defer to production configuration.", - ); - this.dbtTerminal.error( - "manifestPathForDeferral", - "manifestPathForDeferral is not present", - configNotPresent, - ); - throw configNotPresent; - } - return manifestPathForDeferral; - } - if (manifestPathType === ManifestPathType.REMOTE) { - try { - this.validationProvider.throwIfNotAuthenticated(); - } catch (err) { - throw new Error( - "Defer to production is currently enabled with 'DataPilot dbt integration' mode. It requires a valid Altimate AI API key and instance name in the settings. In order to run dbt commands, please either switch to Local Path mode or disable the feature or add an API key / instance name.", - ); - } - - this.dbtTerminal.debug( - "remoteManifest", - `fetching artifact url for dbtCoreIntegrationId: ${dbtCoreIntegrationId}`, - ); - try { - const response = await this.altimateRequest.fetchArtifactUrl( - "manifest", - dbtCoreIntegrationId!, - ); - const manifestPath = await this.altimateRequest.downloadFileLocally( - response.url, - this.projectRoot, - ); - console.log(`Set remote manifest path: ${manifestPath}`); - return manifestPath; - } catch (error) { - if (error instanceof NotFoundError) { - const manifestNotFoundError = new Error( - "Unable to download remote manifest file. 
Did you upload your manifest using the Altimate DataPilot CLI?", - ); - this.dbtTerminal.error( - "remoteManifestError", - "Unable to download remote manifest file.", - manifestNotFoundError, - ); - throw manifestNotFoundError; - } - throw error; - } - } - throw new Error(`Invalid manifestPathType: ${manifestPathType}`); - } - - private async getDeferParams(): Promise { - const deferConfig = this.deferToProdService.getDeferConfigByProjectRoot( - this.projectRoot.fsPath, - ); - const { - deferToProduction, - manifestPathForDeferral, - favorState, - manifestPathType, - dbtCoreIntegrationId, - } = deferConfig; - if (!deferToProduction) { - this.dbtTerminal.debug("deferToProd", "defer to prod not enabled"); - return []; - } - const manifestPath = await this.getDeferManifestPath( - manifestPathType, - manifestPathForDeferral, - dbtCoreIntegrationId, - ); - const args = ["--defer", "--state", manifestPath]; - if (favorState) { - args.push("--favor-state"); - } - this.dbtTerminal.debug( - "deferToProd", - `executing dbt command with defer params ${manifestPathType} mode`, - true, - args, - ); - - if (manifestPathType === ManifestPathType.REMOTE) { - this.altimateRequest.sendDeferToProdEvent(ManifestPathType.REMOTE); - } - return args; - } - - private async addDeferParams(command: DBTCommand) { - const deferParams = await this.getDeferParams(); - deferParams.forEach((param) => command.addArgument(param)); - return command; - } - - protected dbtCoreCommand(command: DBTCommand) { - command.addArgument("--project-dir"); - command.addArgument(this.projectRoot.fsPath); - if (this.profilesDir) { - command.addArgument("--profiles-dir"); - command.addArgument(this.profilesDir); - } - if (this.targetName) { - command.addArgument("--target"); - command.addArgument(this.targetName); - } - command.setExecutionStrategy(this.pythonDBTCommandExecutionStrategy); - return command; - } - - // internal commands - async unsafeCompileNode(modelName: string): Promise { - this.throwBridgeErrorIfAvailable(); - const output = await this.python?.lock( - (python) => - python!`to_dict(project.compile_node(project.get_ref_node(${modelName})))`, - ); - return output.compiled_sql; - } - - async unsafeCompileQuery( - query: string, - originalModelName: string | undefined = undefined, - ): Promise { - this.throwBridgeErrorIfAvailable(); - const output = await this.python?.lock( - (python) => - python!`to_dict(project.compile_sql(${query}, ${originalModelName}))`, - ); - return output.compiled_sql; - } - - async validateSql(query: string, dialect: string, models: any) { - this.throwBridgeErrorIfAvailable(); - const result = await this.python?.lock( - (python) => - python!`to_dict(validate_sql(${query}, ${dialect}, ${models}))`, - ); - return result; - } - - async validateSQLDryRun(query: string) { - this.throwBridgeErrorIfAvailable(); - const result = await this.python?.lock<{ bytes_processed: string }>( - (python) => python!`to_dict(project.validate_sql_dry_run(${query}))`, - ); - return result; - } - - async getColumnsOfModel(modelName: string) { - this.throwBridgeErrorIfAvailable(); - // Get database and schema - const node = (await this.python?.lock( - (python) => python!`to_dict(project.get_ref_node(${modelName}))`, - )) as ResolveReferenceNodeResult; - // Get columns - if (!node) { - return []; - } - // TODO: fix this type - return this.getColumsOfRelation( - node.database, - node.schema, - node.alias || modelName, - ); - } - - async getColumnsOfSource(sourceName: string, tableName: string) { - 
this.throwBridgeErrorIfAvailable(); - // Get database and schema - const node = (await this.python?.lock( - (python) => - python!`to_dict(project.get_source_node(${sourceName}, ${tableName}))`, - )) as ResolveReferenceSourceResult; - // Get columns - if (!node) { - return []; - } - return this.getColumsOfRelation( - node.database, - node.schema, - node.identifier, - ); - } - - private async getColumsOfRelation( - database: string | undefined, - schema: string | undefined, - objectName: string, - ): Promise { - this.throwBridgeErrorIfAvailable(); - return this.python?.lock( - (python) => - python!`to_dict(project.get_columns_in_relation(project.create_relation(${database}, ${schema}, ${objectName})))`, - ); - } - - async getBulkCompiledSQL(models: NodeMetaData[]) { - const result: Record = {}; - for (const m of models) { - try { - const compiledSQL = await this.unsafeCompileNode(m.name); - result[m.uniqueId] = compiledSQL; - } catch (e) { - this.dbtTerminal.error( - "getBulkCompiledSQL", - `Unable to compile sql for model ${m.uniqueId}`, - e, - true, - ); - } - } - return result; - } - - async getBulkSchemaFromDB( - nodes: DBTNode[], - cancellationToken: CancellationToken, - ): Promise> { - if (nodes.length === 0) { - return {}; - } - const result: Record = {}; - for (const n of nodes) { - if (cancellationToken.isCancellationRequested) { - break; - } - if (n.resource_type === DBTProject.RESOURCE_TYPE_SOURCE) { - const source = n as SourceNode; - result[n.unique_id] = await this.getColumnsOfSource( - source.name, - source.table, - ); - } else { - const model = n as Node; - result[n.unique_id] = await this.getColumnsOfModel(model.name); - } - } - return result; - } - - async validateWhetherSqlHasColumns( - sql: string, - dialect: string, - ): Promise { - this.throwBridgeErrorIfAvailable(); - return this.python?.lock( - (python) => - python!`to_dict(validate_whether_sql_has_columns(${sql}, ${dialect}))`, - ); - } - - async fetchSqlglotSchema(sql: string, dialect: string): Promise { - this.throwBridgeErrorIfAvailable(); - return this.python?.lock( - (python) => python!`to_dict(fetch_schema_from_sql(${sql}, ${dialect}))`, - ); - } - - async getCatalog(): Promise { - this.throwBridgeErrorIfAvailable(); - return await this.python?.lock( - (python) => python!`to_dict(project.get_catalog())`, - ); - } - - // get dbt config - private async findModelPaths(): Promise { - return ( - await this.python.lock( - (python) => python`to_dict(project.config.model_paths)`, - ) - ).map((modelPath: string) => { - if (!path.isAbsolute(modelPath)) { - return path.join(this.projectRoot.fsPath, modelPath); - } - return modelPath; - }); - } - - private async findSeedPaths(): Promise { - return ( - await this.python.lock( - (python) => python`to_dict(project.config.seed_paths)`, - ) - ).map((seedPath: string) => { - if (!path.isAbsolute(seedPath)) { - return path.join(this.projectRoot.fsPath, seedPath); - } - return seedPath; - }); - } - - getDebounceForRebuildManifest() { - return 2000; - } - - private async findMacroPaths(): Promise { - return ( - await this.python.lock( - (python) => python`to_dict(project.config.macro_paths)`, - ) - ).map((macroPath: string) => { - if (!path.isAbsolute(macroPath)) { - return path.join(this.projectRoot.fsPath, macroPath); - } - return macroPath; - }); - } - - private async findTargetPath(): Promise { - let targetPath = await this.python.lock( - (python) => python`to_dict(project.config.target_path)`, - ); - if (!path.isAbsolute(targetPath)) { - targetPath = 
path.join(this.projectRoot.fsPath, targetPath); - } - return targetPath; - } - - private async findPackagesInstallPath(): Promise { - let packageInstallPath = await this.python.lock( - (python) => python`to_dict(project.config.packages_install_path)`, - ); - if (!path.isAbsolute(packageInstallPath)) { - packageInstallPath = path.join( - this.projectRoot.fsPath, - packageInstallPath, - ); - } - return packageInstallPath; - } - - private async findVersion(): Promise { - return this.python?.lock( - (python) => python!`to_dict(project.get_dbt_version())`, - ); - } - - private async findProjectName(): Promise { - return this.python?.lock( - (python) => python!`to_dict(project.config.project_name)`, - ); - } - - protected throwBridgeErrorIfAvailable() { - const allDiagnostics: DiagnosticCollection[] = [ - this.pythonBridgeDiagnostics, - this.projectConfigDiagnostics, - this.rebuildManifestDiagnostics, - ]; - - for (const diagnosticCollection of allDiagnostics) { - for (const [_, diagnostics] of diagnosticCollection) { - const error = diagnostics.find( - (diagnostic) => diagnostic.severity === DiagnosticSeverity.Error, - ); - if (error) { - throw new Error(error.message); - } - } - } - } - - findPackageVersion(packageName: string) { - if (!this.packagesInstallPath) { - throw new Error("Missing packages install path"); - } - if (!packageName) { - throw new Error("Invalid package name"); - } - - const dbtProjectYmlFilePath = path.join( - this.packagesInstallPath, - packageName, - "dbt_project.yml", - ); - if (!existsSync(dbtProjectYmlFilePath)) { - throw new Error("Package not installed"); - } - const fileContents = readFileSync(dbtProjectYmlFilePath, { - encoding: "utf-8", - }); - if (!fileContents) { - throw new Error(`${packageName} has empty dbt_project.yml`); - } - const parsedConfig = parse(fileContents, { - strict: false, - uniqueKeys: false, - maxAliasCount: -1, - }); - if (!parsedConfig?.version) { - throw new Error(`Missing version in ${dbtProjectYmlFilePath}`); - } - - return parsedConfig.version; - } - - async dispose() { - try { - await this.executionInfrastructure.closePythonBridge(this.python); - } catch (error) {} // We don't care about errors here. 
- this.rebuildManifestDiagnostics.clear(); - this.pythonBridgeDiagnostics.clear(); - while (this.disposables.length) { - const x = this.disposables.pop(); - if (x) { - x.dispose(); - } - } - } - - async performDatapilotHealthcheck({ - manifestPath, - catalogPath, - config, - configPath, - }: HealthcheckArgs): Promise { - this.throwBridgeErrorIfAvailable(); - const healthCheckThread = this.executionInfrastructure.createPythonBridge( - this.projectRoot.fsPath, - ); - try { - await this.createPythonDbtProject(healthCheckThread); - await healthCheckThread.ex`from dbt_healthcheck import *`; - const result = await healthCheckThread.lock( - (python) => - python!`to_dict(project_healthcheck(${manifestPath}, ${catalogPath}, ${configPath}, ${config}, ${this.altimateRequest.getAIKey()}, ${this.altimateRequest.getInstanceName()}, ${AltimateRequest.ALTIMATE_URL}))`, - ); - return result; - } finally { - healthCheckThread.end(); - } - } - - private async getDeferConfig() { - try { - const root = getProjectRelativePath(this.projectRoot); - const currentConfig: Record = - this.deferToProdService.getDeferConfigByWorkspace(); - const { - deferToProduction, - manifestPathForDeferral, - favorState, - manifestPathType, - dbtCoreIntegrationId, - } = currentConfig[root]; - const manifestFolder = await this.getDeferManifestPath( - manifestPathType, - manifestPathForDeferral, - dbtCoreIntegrationId, - ); - const manifestPath = path.join(manifestFolder, DBTProject.MANIFEST_FILE); - return { deferToProduction, manifestPath, favorState }; - } catch (error) { - this.dbtTerminal.debug( - "dbtCoreIntegration:getDeferConfig", - "An error occured while getting defer config: " + - (error as Error).message, - ); - } - return { deferToProduction: false, manifestPath: null, favorState: false }; - } - - async applyDeferConfig(): Promise { - const { deferToProduction, manifestPath, favorState } = - await this.getDeferConfig(); - await this.python?.lock( - (python) => - python!`project.set_defer_config(${deferToProduction}, ${manifestPath}, ${favorState})`, - ); - await this.refreshProjectConfig(); - await this.rebuildManifest(); - } - - async applySelectedTarget(): Promise { - await this.refreshProjectConfig(); - await this.rebuildManifest(); - } - - throwDiagnosticsErrorIfAvailable(): void { - this.throwBridgeErrorIfAvailable(); - } -} diff --git a/src/dbt_client/dbtFusionCommandIntegration.ts b/src/dbt_client/dbtFusionCommandIntegration.ts deleted file mode 100644 index d9353d43a..000000000 --- a/src/dbt_client/dbtFusionCommandIntegration.ts +++ /dev/null @@ -1,504 +0,0 @@ -import { - CancellationToken, - CancellationTokenSource, - Diagnostic, - DiagnosticSeverity, - Range, - Uri, - window, -} from "vscode"; -import { getFirstWorkspacePath, provideSingleton } from "../utils"; -import { - QueryExecution, - DBTCommand, - DBColumn, - Catalog, - DBTNode, - DBTDetection, - DBTProjectDetection, -} from "./dbtIntegration"; -import { - CommandProcessExecutionFactory, - DBTProject, - DBTTerminal, - PythonEnvironment, -} from "../modules"; -import { DBTCloudProjectIntegration, getDBTPath } from "./dbtCloudIntegration"; -import path, { join } from "path"; - -@provideSingleton(DBTFusionCommandDetection) -export class DBTFusionCommandDetection implements DBTDetection { - constructor( - protected commandProcessExecutionFactory: CommandProcessExecutionFactory, - protected pythonEnvironment: PythonEnvironment, - protected terminal: DBTTerminal, - ) {} - - async detectDBT(): Promise { - const dbtPath = getDBTPath(this.pythonEnvironment, 
this.terminal); - try { - this.terminal.debug("DBTCLIDetection", "Detecting dbt fusion cli"); - const checkDBTInstalledProcess = - this.commandProcessExecutionFactory.createCommandProcessExecution({ - command: dbtPath, - args: ["--version"], - cwd: getFirstWorkspacePath(), - }); - const { stdout, stderr } = await checkDBTInstalledProcess.complete(); - if (stderr) { - throw new Error(stderr); - } - if (stdout.includes("dbt-fusion")) { - this.terminal.debug( - "DBTCLIDetectionSuccess", - "dbt fusion cli detected", - ); - return true; - } else { - this.terminal.debug( - "DBTCLIDetectionFailed", - "dbt fusion cli was not found. Detection command returned : " + - stdout, - ); - } - } catch (error) { - this.terminal.warn( - "DBTCLIDetectionError", - "Detection failed with error : " + (error as Error).message, - ); - } - this.terminal.debug( - "DBTCLIDetectionFailed", - "dbt fusion cli was not found. Detection command returning false", - ); - return false; - } -} - -@provideSingleton(DBTFusionCommandProjectDetection) -export class DBTFusionCommandProjectDetection implements DBTProjectDetection { - async discoverProjects(projectDirectories: Uri[]): Promise { - const packagesInstallPaths = projectDirectories.map((projectDirectory) => - path.join(projectDirectory.fsPath, "dbt_packages"), - ); - const filteredProjectFiles = projectDirectories.filter((uri) => { - return !packagesInstallPaths.some((packageInstallPath) => { - return uri.fsPath.startsWith(packageInstallPath!); - }); - }); - if (filteredProjectFiles.length > 20) { - window.showWarningMessage( - `dbt Power User detected ${filteredProjectFiles.length} projects in your workspace, this will negatively affect performance.`, - ); - } - return filteredProjectFiles; - } -} - -@provideSingleton(DBTFusionCommandProjectIntegration) -export class DBTFusionCommandProjectIntegration extends DBTCloudProjectIntegration { - protected dbtCloudCommand(command: DBTCommand) { - command.setExecutionStrategy( - this.cliDBTCommandExecutionStrategyFactory( - this.projectRoot, - this.dbtPath, - ), - ); - return command; - } - - protected async initializePaths() { - // No way to get these paths from the fusion executable - this.targetPath = join(this.projectRoot.fsPath, "target"); - this.modelPaths = [join(this.projectRoot.fsPath, "models")]; - this.seedPaths = [join(this.projectRoot.fsPath, "seeds")]; - this.macroPaths = [join(this.projectRoot.fsPath, "macros")]; - this.packagesInstallPath = join(this.projectRoot.fsPath, "dbt_packages"); - try { - const projectConfig = DBTProject.readAndParseProjectConfig( - this.projectRoot, - ); - this.projectName = projectConfig.name; - } catch (error) { - this.terminal.warn( - "DbtCloudIntegrationProjectNameFromConfigExceptionError", - "project name could not be read from dbt_project.yml, ignoring", - true, - error, - ); - } - } - - async rebuildManifest(): Promise { - // TODO: check whether we should allow parsing for unauthenticated users - // this.throwIfNotAuthenticated(); - if (this.rebuildManifestCancellationTokenSource) { - this.rebuildManifestCancellationTokenSource.cancel(); - this.rebuildManifestCancellationTokenSource = undefined; - } - const command = this.dbtCloudCommand( - this.dbtCommandFactory.createParseCommand(), - ); - command.addArgument("--log-format"); - command.addArgument("json"); - this.rebuildManifestCancellationTokenSource = new CancellationTokenSource(); - command.setToken(this.rebuildManifestCancellationTokenSource.token); - - try { - const result = await command.execute(); - const stderr = 
result.stderr; - // sending stderr everytime to verify in logs whether is coming as empty or not. - this.telemetry.sendTelemetryEvent("dbtCloudParseProjectUserError", { - error: stderr, - adapter: this.getAdapterType() || "unknown", - }); - this.terminal.info( - "dbtFusionParseProject", - "dbt fusion response", - false, - { - command: command.getCommandAsString(), - stderr, - }, - ); - const errorsAndWarnings = stderr - .trim() - .split("\n") - .map((line) => line.trim()) - .filter((line) => Boolean(line)) - .map((line) => - this.parseJSON( - "RebuildManifestErrorsAndWarningsJSONParsing", - line, - false, - ), - ); - const errors = errorsAndWarnings - .filter( - (line) => - line && - line.hasOwnProperty("info") && - line.info.hasOwnProperty("level") && - line.info.hasOwnProperty("msg") && - ["error", "fatal"].includes(line.info.level), - ) - .map((line) => line.info.msg); - const warnings = errorsAndWarnings - .filter( - (line) => - line && - line.hasOwnProperty("info") && - line.info.hasOwnProperty("level") && - line.info.hasOwnProperty("msg") && - line.info.level === "warn", - ) - .map((line) => line.info.msg); - this.rebuildManifestDiagnostics.clear(); - const diagnostics: Array = errors - .map( - (error) => - new Diagnostic( - new Range(0, 0, 999, 999), - error, - DiagnosticSeverity.Error, - ), - ) - .concat( - warnings.map( - (warning) => - new Diagnostic( - new Range(0, 0, 999, 999), - warning, - DiagnosticSeverity.Warning, - ), - ), - ); - if (diagnostics) { - // user error - this.rebuildManifestDiagnostics.set( - Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE), - diagnostics, - ); - } - } catch (error) { - this.telemetry.sendTelemetryError( - "dbtCloudCannotParseProjectCommandExecuteError", - error, - { - adapter: this.getAdapterType() || "unknown", - command: command.getCommandAsString(), - }, - ); - this.rebuildManifestDiagnostics.set( - Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE), - [ - new Diagnostic( - new Range(0, 0, 999, 999), - "Unable to parse dbt cloud cli response. 
If the problem persists please reach out to us: " + - error, - DiagnosticSeverity.Error, - ), - ], - ); - } - } - - async executeSQL( - query: string, - limit: number, - modelName: string, - ): Promise { - const showCommand = this.dbtCloudCommand( - new DBTCommand("Running sql...", [ - "show", - "--log-level", - "debug", - "--inline", - query, - "--limit", - limit.toString(), - "--output", - "json", - "--log-format", - "json", - ]), - ); - const cancellationTokenSource = new CancellationTokenSource(); - showCommand.setToken(cancellationTokenSource.token); - return new QueryExecution( - async () => { - cancellationTokenSource.cancel(); - }, - async () => { - const { stdout, stderr } = await showCommand.execute( - cancellationTokenSource.token, - ); - const exception = this.processJSONErrors(stderr); - if (exception) { - throw exception; - } - const parsedLines = stdout - .trim() - .split("\n") - .map((line) => JSON.parse(line.trim())); - const previewLine = parsedLines.filter( - (line) => - line.hasOwnProperty("data") && line.data.hasOwnProperty("preview"), - ); - if (previewLine.length === 0) { - throw new Error("Could not find previewLine in " + stdout); - } - const compiledSqlLines = parsedLines.filter( - (line) => - line.hasOwnProperty("data") && line.data.hasOwnProperty("sql"), - ); - if (previewLine.length === 0) { - throw new Error("Could not find previewLine in " + stdout); - } - const preview = JSON.parse(previewLine[0].data.preview); - let compiledSql = ""; - // TODO: is there a way to get the last compiled SQL line in fusion? - if (compiledSqlLines.length !== 0) { - compiledSql = compiledSqlLines[compiledSqlLines.length - 1].data.sql; - } - return { - table: { - column_names: preview.length > 0 ? Object.keys(preview[0]) : [], - column_types: - preview.length > 0 - ? 
Object.keys(preview[0]).map((obj: any) => "string") - : [], - rows: preview.map((obj: any) => Object.values(obj)), - }, - compiled_sql: compiledSql, - raw_sql: query, - modelName, - }; - }, - ); - } - - // internal commands - async unsafeCompileNode(modelName: string): Promise { - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Compiling model...", [ - "compile", - "--model", - modelName, - "--quiet", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout.trim(); - } - - async unsafeCompileQuery(query: string): Promise { - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Compiling sql...", [ - "compile", - "--inline", - query, - "--quiet", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout.trim(); - } - - async getColumnsOfSource( - sourceName: string, - tableName: string, - ): Promise { - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting columns of source...", [ - "compile", - "--inline", - `{% set output = [] %}{% for result in adapter.get_columns_in_relation(source('${sourceName}', '${tableName}')) %} {% do output.append({"column": result.name, "dtype": result.dtype}) %} {% endfor %} {{ tojson(output) }}`, - "--quiet", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - if (stderr) { - throw new Error(stderr); - } - return JSON.parse(stdout.trim()); - } - - async getColumnsOfModel(modelName: string): Promise { - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting columns of model...", [ - "compile", - "--inline", - `{% set output = [] %}{% for result in adapter.get_columns_in_relation(ref('${modelName}')) %} {% do output.append({"column": result.name, "dtype": result.dtype}) %} {% endfor %} {{ tojson(output) }}`, - "--quiet", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - if (stderr) { - throw new Error(stderr); - } - return JSON.parse(stdout.trim()); - } - - async getBulkSchemaFromDB( - nodes: DBTNode[], - cancellationToken: CancellationToken, - ): Promise> { - if (nodes.length === 0) { - return {}; - } - const bulkModelQuery = ` -{% set result = {} %} -{% for n in ${JSON.stringify(nodes)} %} - {% set columns = adapter.get_columns_in_relation(ref(n["name"])) %} - {% set new_columns = [] %} - {% for column in columns %} - {% do new_columns.append({"column": column.name, "dtype": column.dtype}) %} - {% endfor %} - {% do result.update({n["unique_id"]:new_columns}) %} -{% endfor %} -{% for n in graph.sources.values() %} - {% set columns = adapter.get_columns_in_relation(source(n["source_name"], n["identifier"])) %} - {% set new_columns = [] %} - {% for column in columns %} - {% do new_columns.append({"column": column.name, "dtype": column.dtype}) %} - {% endfor %} - {% do result.update({n["unique_id"]:new_columns}) %} -{% endfor %} -{{ tojson(result) }}`; - console.log(bulkModelQuery); - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting catalog...", [ - "compile", - "--inline", - bulkModelQuery.trim().split("\n").join(""), - "--quiet", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - if (stderr) { - throw new Error(stderr); - } - return JSON.parse(stdout.trim()); - } - - async getCatalog(): Promise { - const bulkModelQuery = ` -{% set result = [] %} -{% for n in graph.nodes.values() %} - {% if n.resource_type == "test" 
or - n.resource_type == "analysis" or - n.resource_type == "sql_operation" or - n.config.materialized == "ephemeral" %} - {% continue %} - {% endif %} - {% set columns = adapter.get_columns_in_relation(ref(n["name"])) %} - {% for column in columns %} - {% do result.append({ - "table_database": n.database, - "table_schema": n.schema, - "table_name": n.name, - "column_name": column.name, - "column_type": column.dtype, - }) %} - {% endfor %} -{% endfor %} -{% for n in graph.sources.values() %} - {% set columns = adapter.get_columns_in_relation(source(n["source_name"], n["identifier"])) %} - {% for column in columns %} - {% do result.append({ - "table_database": n.database, - "table_schema": n.schema, - "table_name": n.name, - "column_name": column.name, - "column_type": column.dtype, - }) %} - {% endfor %} -{% endfor %} -{{ tojson(result) }}`; - - const compileQueryCommand = this.dbtCloudCommand( - new DBTCommand("Getting catalog...", [ - "compile", - "--inline", - bulkModelQuery, - "--quiet", - ]), - ); - const { stdout, stderr } = await compileQueryCommand.execute(); - if (!stdout) { - throw new Error("Could not get bulk schema from response: " + stdout); - } - if (stderr) { - throw new Error(stderr); - } - const result: Catalog = JSON.parse(stdout); - return result; - } - - async debug(command: DBTCommand): Promise { - const { stdout, stderr } = await this.dbtCloudCommand(command).execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout; - } - - async generateDocs(_: DBTCommand) { - throw new Error("dbt fusion does not support docs generation"); - } - - async clean(command: DBTCommand): Promise { - const { stdout, stderr } = await this.dbtCloudCommand(command).execute(); - if (stderr) { - throw new Error(stderr); - } - return stdout; - } -} diff --git a/src/dbt_client/dbtIntegration.ts b/src/dbt_client/dbtIntegration.ts deleted file mode 100644 index 32329b37f..000000000 --- a/src/dbt_client/dbtIntegration.ts +++ /dev/null @@ -1,687 +0,0 @@ -import { - CancellationToken, - Diagnostic, - Disposable, - ProgressLocation, - Uri, - window, - workspace, -} from "vscode"; -import { - extendErrorWithSupportLinks, - getFirstWorkspacePath, - provideSingleton, -} from "../utils"; -import { PythonBridge, pythonBridge } from "python-bridge"; -import { provide } from "inversify-binding-decorators"; -import { - CommandProcessExecution, - CommandProcessExecutionFactory, - CommandProcessResult, -} from "../commandProcessExecution"; -import { PythonEnvironment } from "../manifest/pythonEnvironment"; -import { existsSync } from "fs"; -import { TelemetryService } from "../telemetry"; -import { DBTTerminal } from "./dbtTerminal"; -import { - AltimateRequest, - NoCredentialsError, - ValidateSqlParseErrorResponse, -} from "../altimate"; -import { ProjectHealthcheck } from "./dbtCoreIntegration"; -import { NodeMetaData } from "../domain"; - -interface DBTCommandExecution { - command: (token?: CancellationToken) => Promise; - statusMessage: string; - showProgress?: boolean; - focus?: boolean; - token?: CancellationToken; -} - -export interface DBTCommandExecutionStrategy { - execute( - command: DBTCommand, - token?: CancellationToken, - ): Promise; -} - -@provideSingleton(CLIDBTCommandExecutionStrategy) -export class CLIDBTCommandExecutionStrategy - implements DBTCommandExecutionStrategy -{ - constructor( - protected commandProcessExecutionFactory: CommandProcessExecutionFactory, - protected pythonEnvironment: PythonEnvironment, - protected terminal: DBTTerminal, - protected telemetry: 
TelemetryService, - protected cwd: Uri, - protected dbtPath: string, - ) {} - - async execute( - command: DBTCommand, - token?: CancellationToken, - ): Promise { - const commandExecution = this.executeCommand(command, token); - const executionPromise = command.logToTerminal - ? (await commandExecution).completeWithTerminalOutput() - : (await commandExecution).complete(); - return executionPromise; - } - - protected async executeCommand( - command: DBTCommand, - token?: CancellationToken, - ): Promise { - if (command.logToTerminal && command.focus) { - await this.terminal.show(true); - } - this.telemetry.sendTelemetryEvent("dbtCommand", { - command: command.getCommandAsString(), - }); - if (command.logToTerminal) { - this.terminal.log( - `> Executing task: ${command.getCommandAsString()}\n\r`, - ); - } - const { args } = command!; - if ( - !this.pythonEnvironment.pythonPath || - !this.pythonEnvironment.environmentVariables - ) { - throw Error( - "Could not launch command as python environment is not available", - ); - } - const tokens: CancellationToken[] = []; - if (token !== undefined) { - tokens.push(token); - } - if (command.token !== undefined) { - tokens.push(command.token); - } - return this.commandProcessExecutionFactory.createCommandProcessExecution({ - command: this.dbtPath, - args, - tokens, - cwd: this.cwd.fsPath, - envVars: this.pythonEnvironment.environmentVariables, - }); - } -} - -@provideSingleton(PythonDBTCommandExecutionStrategy) -export class PythonDBTCommandExecutionStrategy - implements DBTCommandExecutionStrategy -{ - constructor( - private commandProcessExecutionFactory: CommandProcessExecutionFactory, - private pythonEnvironment: PythonEnvironment, - private terminal: DBTTerminal, - private telemetry: TelemetryService, - ) {} - - async execute( - command: DBTCommand, - token?: CancellationToken, - ): Promise { - return ( - await this.executeCommand(command, token) - ).completeWithTerminalOutput(); - } - - private async executeCommand( - command: DBTCommand, - token?: CancellationToken, - ): Promise { - this.terminal.log(`> Executing task: ${command.getCommandAsString()}\n\r`); - this.telemetry.sendTelemetryEvent("dbtCommand", { - command: command.getCommandAsString(), - }); - if (command.focus) { - await this.terminal.show(true); - } - - const { args } = command!; - if ( - !this.pythonEnvironment.pythonPath || - !this.pythonEnvironment.environmentVariables - ) { - throw Error( - "Could not launch command as python environment is not available", - ); - } - const tokens: CancellationToken[] = []; - if (token !== undefined) { - tokens.push(token); - } - if (command.token !== undefined) { - tokens.push(command.token); - } - return this.commandProcessExecutionFactory.createCommandProcessExecution({ - command: this.pythonEnvironment.pythonPath, - args: ["-c", this.dbtCommand(args)], - tokens, - cwd: getFirstWorkspacePath(), - envVars: this.pythonEnvironment.environmentVariables, - }); - } - - private dbtCommand(args: string[]): string { - args = args.map((arg) => `r"""${arg.replace(/"/g, '\\"')}"""`); - const dbtCustomRunnerImport = workspace - .getConfiguration("dbt") - .get( - "dbtCustomRunnerImport", - "from dbt.cli.main import dbtRunner", - ); - return `has_dbt_runner = True -try: - ${dbtCustomRunnerImport} -except Exception: - has_dbt_runner = False -if has_dbt_runner: - dbt_cli = dbtRunner() - dbt_cli.invoke([${args}]) -else: - import dbt.main - dbt.main.main([${args}])`; - } -} - -export class DBTCommand { - constructor( - public statusMessage: string, - public 
args: string[], - public focus: boolean = false, - public showProgress: boolean = false, - public logToTerminal: boolean = false, - public executionStrategy?: DBTCommandExecutionStrategy, - public token?: CancellationToken, - public downloadArtifacts: boolean = false, - ) {} - - addArgument(arg: string) { - this.args.push(arg); - } - - getCommandAsString() { - return "dbt " + this.args.join(" "); - } - - setExecutionStrategy(executionStrategy: DBTCommandExecutionStrategy) { - this.executionStrategy = executionStrategy; - } - - execute(token?: CancellationToken) { - if (this.executionStrategy === undefined) { - throw new Error("Execution strategy is required to run dbt commands"); - } - return this.executionStrategy.execute(this, token); - } - - setToken(token: CancellationToken) { - this.token = token; - } -} - -export interface RunModelParams { - plusOperatorLeft: string; - modelName: string; - plusOperatorRight: string; -} - -export interface ExecuteSQLResult { - table: { - column_names: string[]; - column_types: string[]; - rows: any[][]; - }; - raw_sql: string; - compiled_sql: string; - modelName: string; -} - -export class ExecuteSQLError extends Error { - compiled_sql: string; - constructor(message: string, compiled_sql: string) { - super(message); - this.compiled_sql = compiled_sql; - } -} - -export interface CompilationResult { - compiled_sql: string; -} - -// TODO: standardize error handling -export class DBTIntegrationError extends Error {} -export class DBTIntegrationUnknownError extends Error {} - -export interface DBTDetection { - detectDBT(): Promise; -} - -export interface DBTInstallion { - installDBT(): Promise; -} - -export interface HealthcheckArgs { - manifestPath: string; - catalogPath?: string; - config?: any; - configPath?: string; -} - -export interface DBTProjectDetection { - discoverProjects(projectConfigFiles: Uri[]): Promise; -} - -export class QueryExecution { - constructor( - private cancelFunc: () => Promise, - private queryResult: () => Promise, - ) {} - - cancel(): Promise { - return this.cancelFunc(); - } - - executeQuery(): Promise { - return this.queryResult(); - } -} - -export type DBColumn = { column: string; dtype: string }; - -export type Node = { - unique_id: string; - name: string; - resource_type: string; -}; - -export type SourceNode = { - unique_id: string; - name: string; - resource_type: "source"; - table: string; -}; - -export type DBTNode = Node | SourceNode; - -type CatalogItem = { - table_database: string; - table_schema: string; - table_name: string; - column_name: string; - column_type: string; -}; - -export type Catalog = CatalogItem[]; - -export interface DBTProjectIntegration extends Disposable { - // initialize execution infrastructure - initializeProject(): Promise; - // called when project configuration is changed - refreshProjectConfig(): Promise; - // Change target - setSelectedTarget(targetName: string): Promise; - getTargetNames(): Promise>; - // retrieve dbt configs - getTargetPath(): string | undefined; - getModelPaths(): string[] | undefined; - getSeedPaths(): string[] | undefined; - getMacroPaths(): string[] | undefined; - getPackageInstallPath(): string | undefined; - getAdapterType(): string | undefined; - getVersion(): number[] | undefined; - getProjectName(): string; - getSelectedTarget(): string | undefined; - // parse manifest - rebuildManifest(): Promise; - // execute queries - executeSQL( - query: string, - limit: number, - modelName: string, - ): Promise; - // dbt commands - runModel(command: DBTCommand): Promise; - 
buildModel(command: DBTCommand): Promise; - buildProject(command: DBTCommand): Promise; - runTest(command: DBTCommand): Promise; - runModelTest(command: DBTCommand): Promise; - compileModel(command: DBTCommand): Promise; - generateDocs(command: DBTCommand): Promise; - clean(command: DBTCommand): Promise; - executeCommandImmediately(command: DBTCommand): Promise; - deps(command: DBTCommand): Promise; - debug(command: DBTCommand): Promise; - // altimate commands - unsafeCompileNode(modelName: string): Promise; - unsafeCompileQuery( - query: string, - originalModelName: string | undefined, - ): Promise; - validateSql( - query: string, - dialect: string, - models: any, // TODO: type this - ): Promise; - validateSQLDryRun(query: string): Promise<{ - bytes_processed: string; // TODO: create type - }>; - getColumnsOfSource( - sourceName: string, - tableName: string, - ): Promise; - getColumnsOfModel(modelName: string): Promise; - getCatalog(): Promise; - getDebounceForRebuildManifest(): number; - getBulkSchemaFromDB( - nodes: DBTNode[], - cancellationToken: CancellationToken, - ): Promise>; - getBulkCompiledSQL(models: NodeMetaData[]): Promise>; - validateWhetherSqlHasColumns(sql: string, dialect: string): Promise; - fetchSqlglotSchema(sql: string, dialect: string): Promise; - findPackageVersion(packageName: string): string | undefined; - performDatapilotHealthcheck( - args: HealthcheckArgs, - ): Promise; - applyDeferConfig(): Promise; - applySelectedTarget(): Promise; - getAllDiagnostic(): Diagnostic[]; - throwDiagnosticsErrorIfAvailable(): void; - getPythonBridgeStatus(): boolean; - cleanupConnections(): Promise; -} - -@provide(DBTCommandExecutionInfrastructure) -export class DBTCommandExecutionInfrastructure { - private queues: Map = new Map< - string, - DBTCommandExecution[] - >(); - private queueStates: Map = new Map(); - - constructor( - private pythonEnvironment: PythonEnvironment, - private telemetry: TelemetryService, - private altimate: AltimateRequest, - private terminal: DBTTerminal, - ) {} - - createPythonBridge(cwd: string): PythonBridge { - let pythonPath = this.pythonEnvironment.pythonPath; - const envVars = this.pythonEnvironment.environmentVariables; - - if (pythonPath.endsWith("python.exe")) { - // replace python.exe with pythonw.exe if path exists - const pythonwPath = pythonPath.replace("python.exe", "pythonw.exe"); - if (existsSync(pythonwPath)) { - this.terminal.debug( - "DBTCommandExecutionInfrastructure", - `Changing python path to ${pythonwPath}`, - ); - pythonPath = pythonwPath; - } - } - this.terminal.debug( - "DBTCommandExecutionInfrastructure", - "Starting python bridge", - { - pythonPath, - cwd, - }, - ); - return pythonBridge({ - python: pythonPath, - cwd: cwd, - env: { - ...envVars, - PYTHONPATH: __dirname, - }, - detached: true, - }); - } - - async closePythonBridge(bridge: PythonBridge) { - this.terminal.debug("dbtIntegration", `Closing python bridge`); - try { - await bridge.disconnect(); - await bridge.end(); - } catch (_) {} - } - - createQueue(queueName: string) { - this.queues.set(queueName, []); - } - - async addCommandToQueue( - queueName: string, - command: DBTCommand, - ): Promise { - this.queues.get(queueName)!.push({ - command: async (token) => { - await command.execute(token); - }, - statusMessage: command.statusMessage, - focus: command.focus, - token: command.token, - showProgress: command.showProgress, - }); - this.pickCommandToRun(queueName); - return undefined; - } - - private async pickCommandToRun(queueName: string): Promise { - const queue = 
this.queues.get(queueName)!; - const running = this.queueStates.get(queueName); - if (!running && queue.length > 0) { - this.queueStates.set(queueName, true); - const { command, statusMessage, focus, showProgress } = queue.shift()!; - const commandExecution = async (token?: CancellationToken) => { - try { - await command(token); - } catch (error) { - if (error instanceof NoCredentialsError) { - this.altimate.handlePreviewFeatures(); - return; - } - window.showErrorMessage( - extendErrorWithSupportLinks( - `Could not run command '${statusMessage}': ` + error + ".", - ), - ); - this.telemetry.sendTelemetryError("queueRunCommandError", error, { - command: statusMessage, - }); - } - }; - - if (showProgress) { - await window.withProgress( - { - location: focus - ? ProgressLocation.Notification - : ProgressLocation.Window, - cancellable: true, - title: statusMessage, - }, - async (_, token) => { - await commandExecution(token); - }, - ); - } else { - await commandExecution(); - } - this.queueStates.set(queueName, false); - this.pickCommandToRun(queueName); - } - } - - async runCommand(command: DBTCommand) { - const commandExecution: DBTCommandExecution = { - command: async (token) => { - await command.execute(token); - }, - statusMessage: command.statusMessage, - focus: command.focus, - }; - await window.withProgress( - { - location: commandExecution.focus - ? ProgressLocation.Notification - : ProgressLocation.Window, - cancellable: true, - title: commandExecution.statusMessage, - }, - async (_, token) => { - try { - return await commandExecution.command(token); - } catch (error) { - window.showErrorMessage( - extendErrorWithSupportLinks( - `Could not run command '${commandExecution.statusMessage}': ` + - (error as Error).message + - ".", - ), - ); - this.telemetry.sendTelemetryError("runCommandError", error, { - command: commandExecution.statusMessage, - }); - } - }, - ); - } -} - -@provideSingleton(DBTCommandFactory) -export class DBTCommandFactory { - createVersionCommand(): DBTCommand { - return new DBTCommand("Detecting dbt version...", ["--version"]); - } - - createParseCommand(): DBTCommand { - return new DBTCommand("Parsing dbt project...", ["parse"]); - } - - createRunModelCommand(params: RunModelParams): DBTCommand { - const { plusOperatorLeft, modelName, plusOperatorRight } = params; - const buildModelCommandAdditionalParams = workspace - .getConfiguration("dbt") - .get("runModelCommandAdditionalParams", []); - - return new DBTCommand( - "Running dbt model...", - [ - "run", - "--select", - `${plusOperatorLeft}${modelName}${plusOperatorRight}`, - ...buildModelCommandAdditionalParams, - ], - true, - true, - true, - ); - } - - createBuildModelCommand(params: RunModelParams): DBTCommand { - const { plusOperatorLeft, modelName, plusOperatorRight } = params; - const buildModelCommandAdditionalParams = workspace - .getConfiguration("dbt") - .get("buildModelCommandAdditionalParams", []); - - return new DBTCommand( - "Building dbt model...", - [ - "build", - "--select", - `${plusOperatorLeft}${modelName}${plusOperatorRight}`, - ...buildModelCommandAdditionalParams, - ], - true, - true, - true, - ); - } - - createBuildProjectCommand(): DBTCommand { - return new DBTCommand( - "Building dbt project...", - ["build"], - true, - true, - true, - ); - } - - createTestModelCommand(testName: string): DBTCommand { - const testModelCommandAdditionalParams = workspace - .getConfiguration("dbt") - .get("testModelCommandAdditionalParams", []); - - return new DBTCommand( - "Testing dbt model...", - ["test", 
"--select", testName, ...testModelCommandAdditionalParams], - true, - true, - true, - ); - } - - createCompileModelCommand(params: RunModelParams): DBTCommand { - const { plusOperatorLeft, modelName, plusOperatorRight } = params; - return new DBTCommand( - "Compiling dbt models...", - [ - "compile", - "--select", - `${plusOperatorLeft}${modelName}${plusOperatorRight}`, - ], - true, - true, - true, - ); - } - - createDocsGenerateCommand(): DBTCommand { - return new DBTCommand( - "Generating dbt Docs...", - ["docs", "generate"], - true, - true, - true, - ); - } - - createCleanCommand(): DBTCommand { - return new DBTCommand( - "Cleaning dbt project...", - ["clean"], - true, - true, - true, - ); - } - - createInstallDepsCommand(): DBTCommand { - return new DBTCommand("Installing packages...", ["deps"], true, true, true); - } - - createAddPackagesCommand(packages: string[]): DBTCommand { - return new DBTCommand( - "Installing packages...", - ["deps", "--add-package", ...packages], - true, - true, - true, - ); - } - - createDebugCommand(): DBTCommand { - return new DBTCommand("Debugging...", ["debug"], true, true, true); - } -} diff --git a/src/manifest/dbtProject.ts b/src/dbt_client/dbtProject.ts similarity index 54% rename from src/manifest/dbtProject.ts rename to src/dbt_client/dbtProject.ts index 474ab4a37..b4dbbdd2c 100644 --- a/src/manifest/dbtProject.ts +++ b/src/dbt_client/dbtProject.ts @@ -1,16 +1,53 @@ -import { existsSync, readFileSync, writeFileSync } from "fs"; +import { existsSync, writeFileSync } from "fs"; +import { + Catalog, + CATALOG_FILE, + ColumnMetaData, + DataPilotHealtCheckParams, + DBColumn, + DBTCommand, + DBTCommandExecution, + DBTCommandExecutionInfrastructure, + DBTCommandFactory, + DBTDiagnosticData, + DBTFacade, + DBTNode, + DBTProjectIntegration, + DBTProjectIntegrationAdapter, + DBTProjectIntegrationAdapterEvents, + DBTTerminal, + DBT_PROJECT_FILE, + DeferConfig, + HealthcheckArgs, + isResourceHasDbColumns, + isResourceNode, + MANIFEST_FILE, + NoCredentialsError, + NodeMetaData, + ParsedManifest, + ProjectHealthcheck, + QueryExecution, + QueryExecutionResult, + RESOURCE_TYPE_MODEL, + RESOURCE_TYPE_SOURCE, + RunModelParams, + RunResultsEventData, + SourceNode, + Table, + validateSQLUsingSqlGlot, +} from "@altimateai/dbt-integration"; +import { inject } from "inversify"; import * as path from "path"; import { PythonException } from "python-bridge"; import { - CancellationToken, commands, Diagnostic, DiagnosticCollection, + DiagnosticSeverity, Disposable, Event, EventEmitter, - FileSystemWatcher, languages, ProgressLocation, Range, @@ -20,137 +57,99 @@ import { window, workspace, } from "vscode"; -import { parse, YAMLError } from "yaml"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; +import { AltimateRequest, ModelNode } from "../altimate"; +import { AltimateAuthService } from "../services/altimateAuthService"; +import { SharedStateService } from "../services/sharedStateService"; +import { TelemetryService } from "../telemetry"; +import { TelemetryEvents } from "../telemetry/events"; import { - debounce, extendErrorWithSupportLinks, getColumnNameByCase, - setupWatcherHandler, + getProjectRelativePath, } from "../utils"; +import { ValidationProvider } from "../validation_provider"; +import { DBTProjectLog } from "./dbtProjectLog"; import { ManifestCacheChangedEvent, - RebuildManifestStatusChange, ManifestCacheProjectAddedEvent, + RebuildManifestStatusChange, } from "./event/manifestCacheChangedEvent"; import { ProjectConfigChangedEvent } from 
"./event/projectConfigChangedEvent"; -import { DBTProjectLog, DBTProjectLogFactory } from "./modules/dbtProjectLog"; -import { - SourceFileWatchers, - SourceFileWatchersFactory, -} from "./modules/sourceFileWatchers"; -import { TargetWatchersFactory } from "./modules/targetWatchers"; -import { PythonEnvironment } from "./pythonEnvironment"; -import { TelemetryService } from "../telemetry"; -import * as crypto from "crypto"; -import { - DBTProjectIntegration, - DBTCommandFactory, - RunModelParams, - Catalog, - DBTNode, - DBColumn, - SourceNode, - HealthcheckArgs, -} from "../dbt_client/dbtIntegration"; -import { DBTCoreProjectIntegration } from "../dbt_client/dbtCoreIntegration"; -import { DBTCloudProjectIntegration } from "../dbt_client/dbtCloudIntegration"; -import { AltimateRequest, NoCredentialsError } from "../altimate"; -import { ValidationProvider } from "../validation_provider"; -import { ModelNode } from "../altimate"; -import { - ColumnMetaData, - GraphMetaMap, - NodeGraphMap, - NodeMetaData, -} from "../domain"; -import { AltimateConfigProps } from "../webview_provider/insightsPanel"; -import { SharedStateService } from "../services/sharedStateService"; -import { TelemetryEvents } from "../telemetry/events"; import { RunResultsEvent } from "./event/runResultsEvent"; -import { DBTCoreCommandProjectIntegration } from "../dbt_client/dbtCoreCommandIntegration"; -import { Table } from "src/services/dbtLineageService"; -import { DBTFusionCommandProjectIntegration } from "src/dbt_client/dbtFusionCommandIntegration"; +import { PythonEnvironment } from "./pythonEnvironment"; interface FileNameTemplateMap { [key: string]: string; } + interface JsonObj { [key: string]: string | number | undefined; } -export class DBTProject implements Disposable { - private _manifestCacheEvent?: ManifestCacheProjectAddedEvent; - - static DBT_PROJECT_FILE = "dbt_project.yml"; - static MANIFEST_FILE = "manifest.json"; - static CATALOG_FILE = "catalog.json"; - - static RESOURCE_TYPE_MODEL = "model"; - static RESOURCE_TYPE_MACRO = "macro"; - static RESOURCE_TYPE_ANALYSIS = "analysis"; - static RESOURCE_TYPE_SOURCE = "source"; - static RESOURCE_TYPE_EXPOSURE = "exposure"; - static RESOURCE_TYPE_SEED = "seed"; - static RESOURCE_TYPE_SNAPSHOT = "snapshot"; - static RESOURCE_TYPE_TEST = "test"; - static RESOURCE_TYPE_METRIC = "semantic_model"; +export class DBTProject implements Disposable, DBTFacade { + private _manifestCacheEvent?: ManifestCacheProjectAddedEvent; readonly projectRoot: Uri; - private projectConfig: any; // TODO: typing - private dbtProjectIntegration: DBTProjectIntegration; + private dbtProjectIntegration: DBTProjectIntegrationAdapter; private _onProjectConfigChanged = new EventEmitter(); public onProjectConfigChanged = this._onProjectConfigChanged.event; private _onRunResults = new EventEmitter(); public onRunResults = this._onRunResults.event; - private sourceFileWatchers: SourceFileWatchers; - public onSourceFileChanged: Event; + private _onSourceFileChanged = new EventEmitter(); + public onSourceFileChanged = this._onSourceFileChanged.event; private dbtProjectLog?: DBTProjectLog; - private disposables: Disposable[] = [this._onProjectConfigChanged]; - private readonly projectConfigDiagnostics = - languages.createDiagnosticCollection("dbt"); public readonly projectHealth = languages.createDiagnosticCollection("dbt"); + public readonly pythonBridgeDiagnostics = + languages.createDiagnosticCollection("dbt-python-bridge"); + public readonly rebuildManifestDiagnostics = + 
languages.createDiagnosticCollection("dbt-rebuild-manifest"); + public readonly projectConfigDiagnostics = + languages.createDiagnosticCollection("dbt-project-config"); + private disposables: Disposable[] = [ + this._onProjectConfigChanged, + this._onSourceFileChanged, + this.projectHealth, + this.pythonBridgeDiagnostics, + this.rebuildManifestDiagnostics, + this.projectConfigDiagnostics, + ]; private _onRebuildManifestStatusChange = new EventEmitter(); readonly onRebuildManifestStatusChange = this._onRebuildManifestStatusChange.event; private dbSchemaCache: Record = {}; - private depsInitialized = false; + private queues: Map = new Map< + string, + DBTCommandExecution[] + >(); + private queueStates: Map = new Map(); constructor( + @inject(PythonEnvironment) private PythonEnvironment: PythonEnvironment, - private sourceFileWatchersFactory: SourceFileWatchersFactory, - private dbtProjectLogFactory: DBTProjectLogFactory, - private targetWatchersFactory: TargetWatchersFactory, + @inject("Factory") + private dbtProjectLogFactory: ( + onProjectConfigChanged: Event, + ) => DBTProjectLog, private dbtCommandFactory: DBTCommandFactory, private terminal: DBTTerminal, private eventEmitterService: SharedStateService, private telemetry: TelemetryService, - private dbtCoreIntegrationFactory: ( - path: Uri, - projectConfigDiagnostics: DiagnosticCollection, - ) => DBTCoreProjectIntegration, - private dbtCoreCommandIntegrationFactory: ( - path: Uri, - projectConfigDiagnostics: DiagnosticCollection, - ) => DBTCoreCommandProjectIntegration, - private dbtCloudIntegrationFactory: ( - path: Uri, - ) => DBTCloudProjectIntegration, - private dbtFusionCommandIntegrationFactory: ( - path: Uri, - ) => DBTFusionCommandProjectIntegration, + private executionInfrastructure: DBTCommandExecutionInfrastructure, + private dbtIntegrationAdapterFactory: ( + projectRoot: string, + deferConfig: DeferConfig | undefined, + ) => DBTProjectIntegrationAdapter, private altimate: AltimateRequest, private validationProvider: ValidationProvider, + private altimateAuthService: AltimateAuthService, path: Uri, - projectConfig: any, + _projectConfig: any, private _onManifestChanged: EventEmitter, ) { this.projectRoot = path; - this.projectConfig = projectConfig; - try { this.validationProvider.validateCredentialsSilently(); } catch (error) { @@ -162,51 +161,128 @@ export class DBTProject implements Disposable { ); } - this.sourceFileWatchers = - this.sourceFileWatchersFactory.createSourceFileWatchers( - this.onProjectConfigChanged, - ); - this.onSourceFileChanged = this.sourceFileWatchers.onSourceFileChanged; + this.dbtProjectLog = this.dbtProjectLogFactory(this.onProjectConfigChanged); + // Check if dbt loom is installed for telemetry (only for core integration) const dbtIntegrationMode = workspace .getConfiguration("dbt") .get("dbtIntegration", "core"); - switch (dbtIntegrationMode) { - case "cloud": - this.dbtProjectIntegration = this.dbtCloudIntegrationFactory( - this.projectRoot, + if (dbtIntegrationMode === "core") { + this.isDbtLoomInstalled().then((isInstalled) => { + this.telemetry.setTelemetryCustomAttribute( + "dbtLoomInstalled", + `${isInstalled}`, ); - break; - case "fusion": - this.dbtProjectIntegration = this.dbtFusionCommandIntegrationFactory( - this.projectRoot, + }); + } + + // Create the integration adapter which will handle the integration selection internally + this.dbtProjectIntegration = this.dbtIntegrationAdapterFactory( + this.projectRoot.fsPath, + this.retrieveDeferConfigFromSettings(), + ); + + // Set up Node.js 
watcher events to emit VSCode events directly + this.dbtProjectIntegration.on( + DBTProjectIntegrationAdapterEvents.SOURCE_FILE_CHANGED, + () => { + this.terminal.debug( + "DBTProject", + "Received sourceFileChanged event from Node.js file watchers", ); - break; - case "corecommand": - this.dbtProjectIntegration = this.dbtCoreCommandIntegrationFactory( - this.projectRoot, - this.projectConfigDiagnostics, + this._onSourceFileChanged.fire(); + }, + ); + + this.dbtProjectIntegration.on( + DBTProjectIntegrationAdapterEvents.PROJECT_CONFIG_CHANGED, + () => { + this.terminal.debug( + "DBTProject", + "Received projectConfigChanged event from Node.js project config watcher", ); - break; - default: - this.dbtProjectIntegration = this.dbtCoreIntegrationFactory( - this.projectRoot, - this.projectConfigDiagnostics, + const event = new ProjectConfigChangedEvent(this); + this._onProjectConfigChanged.fire(event); + }, + ); + + this.dbtProjectIntegration.on( + DBTProjectIntegrationAdapterEvents.REBUILD_MANIFEST_STATUS_CHANGE, + (status: { inProgress: boolean }) => { + this.terminal.debug( + "DBTProject", + `Received rebuildManifestStatusChange event: inProgress=${status.inProgress}`, ); - break; - } + const event: RebuildManifestStatusChange = { + project: this, + inProgress: status.inProgress, + }; + this._onRebuildManifestStatusChange.fire(event); + }, + ); + + // Handle manifestCreated events from dbtIntegrationAdapter + this.dbtProjectIntegration.on( + DBTProjectIntegrationAdapterEvents.MANIFEST_PARSED, + (parsedManifest: ParsedManifest) => { + this.terminal.debug( + "DBTProject", + "Received manifestParsed event from dbtIntegrationAdapter", + ); + const manifestCacheEvent: ManifestCacheProjectAddedEvent = { + project: this, + nodeMetaMap: parsedManifest.nodeMetaMap, + macroMetaMap: parsedManifest.macroMetaMap, + metricMetaMap: parsedManifest.metricMetaMap, + sourceMetaMap: parsedManifest.sourceMetaMap, + graphMetaMap: parsedManifest.graphMetaMap, + testMetaMap: parsedManifest.testMetaMap, + docMetaMap: parsedManifest.docMetaMap, + exposureMetaMap: parsedManifest.exposureMetaMap, + modelDepthMap: parsedManifest.modelDepthMap, + }; + this._manifestCacheEvent = manifestCacheEvent; + this._onManifestChanged.fire({ added: [manifestCacheEvent] }); + }, + ); + + // Handle runResultsCreated events from dbtIntegrationAdapter + this.dbtProjectIntegration.on( + DBTProjectIntegrationAdapterEvents.RUN_RESULTS_PARSED, + (runResultsData: RunResultsEventData) => { + this.terminal.debug( + "DBTProject", + "Received runResultsParsed event from dbtIntegrationAdapter", + ); + // Extract unique_ids for cache invalidation + const uniqueIds = runResultsData.results.map( + (result) => result.unique_id, + ); + + // Fire the VSCode event with parsed unique_ids + const runResultsEvent = new RunResultsEvent(this, uniqueIds); + this._onRunResults.fire(runResultsEvent); + }, + ); + + // Handle diagnosticsChanged events from dbtIntegrationAdapter + this.dbtProjectIntegration.on( + DBTProjectIntegrationAdapterEvents.DIAGNOSTICS_CHANGED, + () => { + this.terminal.debug( + "DBTProject", + "Received diagnosticsChanged event from dbtIntegrationAdapter", + ); + this.updateDiagnosticsInProblemsPanel(); + }, + ); this.disposables.push( this.dbtProjectIntegration, - this.targetWatchersFactory.createTargetWatchers( - _onManifestChanged, - this._onRunResults, - this.onProjectConfigChanged, - ), this._onManifestChanged.event((event) => { const addedEvent = event.added?.find( - (e) => e.project.projectRoot.fsPath === this.projectRoot.fsPath, + 
(e) => e.project.projectRoot === this.projectRoot, ); if (addedEvent) { this._manifestCacheEvent = addedEvent; @@ -215,10 +291,8 @@ export class DBTProject implements Disposable { this.PythonEnvironment.onPythonEnvironmentChanged(() => this.onPythonEnvironmentChanged(), ), - this.sourceFileWatchers, - this.projectConfigDiagnostics, this.onRunResults((event) => { - this.invalidateCacheUsingLastRun(event.file); + this.invalidateCacheUsingUniqueIds(event.uniqueIds || []); }), ); @@ -230,33 +304,36 @@ export class DBTProject implements Disposable { ); } - private async invalidateCacheUsingLastRun(file: Uri) { - const fileContent = readFileSync(file.fsPath, "utf8").toString(); - if (!fileContent) { - return; + private async isDbtLoomInstalled(): Promise { + const dbtLoomThread = this.executionInfrastructure.createPythonBridge( + this.projectRoot.fsPath, + ); + try { + await dbtLoomThread.ex`from dbt_loom import *`; + return true; + } catch (error) { + return false; + } finally { + await this.executionInfrastructure.closePythonBridge(dbtLoomThread); } + } - try { - const runResults = JSON.parse(fileContent); - for (const n of runResults["results"]) { - if (n["unique_id"] in this.dbSchemaCache) { - delete this.dbSchemaCache[n["unique_id"]]; - } + private invalidateCacheUsingUniqueIds(uniqueIds: string[]) { + for (const uniqueId of uniqueIds) { + if (uniqueId in this.dbSchemaCache) { + delete this.dbSchemaCache[uniqueId]; } - } catch (e) { - this.terminal.error( - "invalidateCacheUsingLastRun", - `Unable to parse run_results.json ${e}`, - e, - true, - ); } } - public getProjectName() { + getProjectName() { return this.dbtProjectIntegration.getProjectName(); } + getProjectRoot() { + return this.projectRoot.fsPath; + } + getSelectedTarget() { return this.dbtProjectIntegration.getSelectedTarget(); } @@ -272,15 +349,12 @@ export class DBTProject implements Disposable { title: "Changing target...", cancellable: false, }, - async () => { - await this.dbtProjectIntegration.setSelectedTarget(targetName); - await this.dbtProjectIntegration.applySelectedTarget(); - }, + () => this.dbtProjectIntegration.setSelectedTarget(targetName), ); } getDBTProjectFilePath() { - return path.join(this.projectRoot.fsPath, DBTProject.DBT_PROJECT_FILE); + return path.join(this.projectRoot.fsPath, DBT_PROJECT_FILE); } getTargetPath() { @@ -308,7 +382,7 @@ export class DBTProject implements Disposable { if (!targetPath) { return; } - return path.join(targetPath, DBTProject.MANIFEST_FILE); + return path.join(targetPath, MANIFEST_FILE); } getCatalogPath() { @@ -316,7 +390,7 @@ export class DBTProject implements Disposable { if (!targetPath) { return; } - return path.join(targetPath, DBTProject.CATALOG_FILE); + return path.join(targetPath, CATALOG_FILE); } getPythonBridgeStatus() { @@ -324,10 +398,120 @@ export class DBTProject implements Disposable { } getAllDiagnostic(): Diagnostic[] { - return this.dbtProjectIntegration.getAllDiagnostic(); + const projectURI = Uri.file( + path.join(this.projectRoot.fsPath, DBT_PROJECT_FILE), + ); + const integrationDiagnostics = + this.getCurrentProjectIntegration().getDiagnostics(); + + // Convert diagnostic data to VSCode Diagnostics + const convertedDiagnostics = [ + ...integrationDiagnostics.pythonBridgeDiagnostics.map( + (data) => + new Diagnostic( + new Range( + data.range?.startLine || 0, + data.range?.startColumn || 0, + data.range?.endLine || 999, + data.range?.endColumn || 999, + ), + data.message, + this.mapSeverityToVSCode(data.severity), + ), + ), + 
...integrationDiagnostics.rebuildManifestDiagnostics.map( + (data) => + new Diagnostic( + new Range( + data.range?.startLine || 0, + data.range?.startColumn || 0, + data.range?.endLine || 999, + data.range?.endColumn || 999, + ), + data.message, + this.mapSeverityToVSCode(data.severity), + ), + ), + ...(integrationDiagnostics.projectConfigDiagnostics || []).map( + (data) => + new Diagnostic( + new Range( + data.range?.startLine || 0, + data.range?.startColumn || 0, + data.range?.endLine || 999, + data.range?.endColumn || 999, + ), + data.message, + this.mapSeverityToVSCode(data.severity), + ), + ), + ]; + + return [ + ...convertedDiagnostics, + ...(this.projectHealth.get(projectURI) || []), + ]; + } + + private mapSeverityToVSCode(severity: string): DiagnosticSeverity { + switch (severity) { + case "error": + return DiagnosticSeverity.Error; + case "warning": + return DiagnosticSeverity.Warning; + case "info": + return DiagnosticSeverity.Information; + case "hint": + return DiagnosticSeverity.Hint; + default: + return DiagnosticSeverity.Error; + } + } + + private convertDiagnosticDataToVSCode(data: DBTDiagnosticData): Diagnostic { + return new Diagnostic( + new Range( + data.range?.startLine || 0, + data.range?.startColumn || 0, + data.range?.endLine || 999, + data.range?.endColumn || 999, + ), + data.message, + this.mapSeverityToVSCode(data.severity), + ); + } + + updateDiagnosticsInProblemsPanel(): void { + const projectURI = Uri.file( + path.join(this.projectRoot.fsPath, DBT_PROJECT_FILE), + ); + const integrationDiagnostics = + this.getCurrentProjectIntegration().getDiagnostics(); + + // Update each diagnostic collection separately + this.pythonBridgeDiagnostics.set( + projectURI, + integrationDiagnostics.pythonBridgeDiagnostics.map((data) => + this.convertDiagnosticDataToVSCode(data), + ), + ); + + this.rebuildManifestDiagnostics.set( + projectURI, + integrationDiagnostics.rebuildManifestDiagnostics.map((data) => + this.convertDiagnosticDataToVSCode(data), + ), + ); + + this.projectConfigDiagnostics.set( + projectURI, + integrationDiagnostics.projectConfigDiagnostics.map((data) => + this.convertDiagnosticDataToVSCode(data), + ), + ); } - async performDatapilotHealthcheck(args: AltimateConfigProps) { + async performDatapilotHealthcheck(args: DataPilotHealtCheckParams) { const manifestPath = this.getManifestPath(); if (!manifestPath) { throw new Error( @@ -350,7 +534,7 @@ export class DBTProject implements Disposable { docsGenerateCommand.focus = false; docsGenerateCommand.logToTerminal = false; docsGenerateCommand.showProgress = false; - await this.generateDocsImmediately(); + await this.unsafeGenerateDocsImmediately(); healthcheckArgs.catalogPath = this.getCatalogPath(); if (!healthcheckArgs.catalogPath) { throw new Error( @@ -364,10 +548,21 @@ export class DBTProject implements Disposable { "Performing healthcheck", healthcheckArgs, ); - const projectHealthcheck = - await this.dbtProjectIntegration.performDatapilotHealthcheck( - healthcheckArgs, + // Create isolated Python bridge for healthcheck + const healthCheckThread = this.executionInfrastructure.createPythonBridge( + this.projectRoot.fsPath, + ); + + let projectHealthcheck: ProjectHealthcheck; + try { + await healthCheckThread.ex`from dbt_utils import *`; + projectHealthcheck = await healthCheckThread.lock( + (python) => + python!`to_dict(project_healthcheck(${healthcheckArgs.manifestPath}, ${healthcheckArgs.catalogPath}, ${healthcheckArgs.configPath}, ${healthcheckArgs.config}, ${this.altimate.getAIKey()}, 
${this.altimate.getInstanceName()}, ${this.altimate.getAltimateUrl()}))`,
        );
+    } finally {
+      await this.executionInfrastructure.closePythonBridge(healthCheckThread);
+    }
     // temp fix: ideally datapilot should return absolute path
     for (const key in projectHealthcheck.model_insights) {
       for (const item of projectHealthcheck.model_insights[key]) {
@@ -377,39 +572,28 @@ export class DBTProject implements Disposable {
     return projectHealthcheck;
   }

-  async initialize() {
-    // ensure we watch all files and reflect changes
-    // This is purely vscode watchers, no need for the project to be fully initialized
-    const dbtProjectConfigWatcher = workspace.createFileSystemWatcher(
-      new RelativePattern(this.projectRoot, DBTProject.DBT_PROJECT_FILE),
-    );
-    setupWatcherHandler(dbtProjectConfigWatcher, async () => {
-      await this.refreshProjectConfig();
-      this.rebuildManifest();
-    });
-    await this.dbtProjectIntegration.initializeProject();
-    await this.refreshProjectConfig();
-    this.rebuildManifest();
-    this.dbtProjectLog = this.dbtProjectLogFactory.createDBTProjectLog(
-      this.onProjectConfigChanged,
-    );
+  async initialize(): Promise<void> {
+    // Create command queue for this project
+    this.createQueue("all");
+
+    try {
+      await this.dbtProjectIntegration.initialize();
+    } catch (error) {
+      window.showErrorMessage(
+        extendErrorWithSupportLinks(
+          "An unexpected error occurred while initializing the dbt project at " +
+            this.projectRoot +
+            ": " +
+            error +
+            ".",
+        ),
+      );
+    }
     // ensure all watchers are cleaned up
-    this.disposables.push(
-      this.dbtProjectLog,
-      dbtProjectConfigWatcher,
-      this.onSourceFileChanged(
-        debounce(async () => {
-          this.terminal.debug(
-            "DBTProject",
-            `SourceFileChanged event fired for "${this.getProjectName()}" at ${
-              this.projectRoot
-            }`,
-          );
-          await this.rebuildManifest();
-        }, this.dbtProjectIntegration.getDebounceForRebuildManifest()),
-      ),
-    );
+    if (this.dbtProjectLog) {
+      this.disposables.push(this.dbtProjectLog);
+    }

     this.terminal.debug(
       "DbtProject",
@@ -417,6 +601,10 @@ export class DBTProject implements Disposable {
     );
   }

+  async rebuildManifest(): Promise<void> {
+    this.dbtProjectIntegration.rebuildManifest();
+  }
+
   private async onPythonEnvironmentChanged() {
     this.terminal.debug(
       "DbtProject",
@@ -427,100 +615,12 @@ export class DBTProject implements Disposable {
     await this.initialize();
   }

-  async refreshProjectConfig() {
-    this.terminal.debug(
-      "DBTProject",
-      `Going to refresh the project "${this.getProjectName()}" at ${
-        this.projectRoot
-      } configuration`,
-    );
-    try {
-      this.projectConfig = DBTProject.readAndParseProjectConfig(
-        this.projectRoot,
-      );
-      await this.dbtProjectIntegration.refreshProjectConfig();
-      this.projectConfigDiagnostics.clear();
-    } catch (error) {
-      if (error instanceof YAMLError) {
-        this.projectConfigDiagnostics.set(
-          Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE),
-          [
-            new Diagnostic(
-              new Range(0, 0, 999, 999),
-              "dbt_project.yml is invalid : " + error.message,
-            ),
-          ],
-        );
-      } else if (error instanceof PythonException) {
-        this.projectConfigDiagnostics.set(
-          Uri.joinPath(this.projectRoot, DBTProject.DBT_PROJECT_FILE),
-          [
-            new Diagnostic(
-              new Range(0, 0, 999, 999),
-              "dbt configuration is invalid : " + error.exception.message,
-            ),
-          ],
-        );
-      }
-      this.terminal.debug(
-        "DBTProject",
-        `An error occurred while trying to refresh the project "${this.getProjectName()}" at ${
-          this.projectRoot
-        } configuration`,
-        error,
-      );
-      this.telemetry.sendTelemetryError("projectConfigRefreshError", error);
-    }
-    const sourcePaths =
this.getModelPaths(); - if (sourcePaths === undefined) { - this.terminal.debug( - "DBTProject", - "sourcePaths is not defined in project in " + this.projectRoot.fsPath, - ); - } - const macroPaths = this.getMacroPaths(); - if (macroPaths === undefined) { - this.terminal.debug( - "DBTProject", - "macroPaths is not defined in " + this.projectRoot.fsPath, - ); - } - const seedPaths = this.getSeedPaths(); - if (seedPaths === undefined) { - this.terminal.debug( - "DBTProject", - "macroPaths is not defined in " + this.projectRoot.fsPath, - ); - } - if (sourcePaths && macroPaths && seedPaths) { - const event = new ProjectConfigChangedEvent(this); - this._onProjectConfigChanged.fire(event); - this.terminal.debug( - "DBTProject", - `firing ProjectConfigChanged event for the project "${this.getProjectName()}" at ${ - this.projectRoot - } configuration`, - "targetPaths", - this.getTargetPath(), - "modelPaths", - this.getModelPaths(), - "seedPaths", - this.getSeedPaths(), - "macroPaths", - this.getMacroPaths(), - "packagesInstallPath", - this.getPackageInstallPath(), - "version", - this.getDBTVersion(), - "adapterType", - this.getAdapterType(), - ); - } else { - this.terminal.warn( - "DBTProject", - "Could not send out ProjectConfigChangedEvent because project is not initialized properly. dbt path settings cannot be determined", - ); - } + async refreshProjectConfig(): Promise { + this.dbtProjectIntegration.refreshProjectConfig(); + } + + async parseManifest(): Promise { + return await this.dbtProjectIntegration.parseManifest(); } getAdapterType() { @@ -530,7 +630,7 @@ export class DBTProject implements Disposable { findPackageName(uri: Uri): string | undefined { const documentPath = uri.path; const pathSegments = documentPath - .replace(new RegExp(this.projectRoot.path + "/", "g"), "") + .replace(new RegExp(this.projectRoot + "/", "g"), "") .split("/"); const packagesInstallPath = this.getPackageInstallPath(); if (packagesInstallPath && uri.fsPath.startsWith(packagesInstallPath)) { @@ -546,238 +646,266 @@ export class DBTProject implements Disposable { ); } - private async rebuildManifest() { - this.terminal.debug( - "DBTProject", - `Going to rebuild the manifest for "${this.getProjectName()}" at ${ - this.projectRoot - }`, - ); - this._onRebuildManifestStatusChange.fire({ - project: this, - inProgress: true, - }); - const installDepsOnProjectInitialization = workspace - .getConfiguration("dbt") - .get("installDepsOnProjectInitialization", true); - if (!this.depsInitialized && installDepsOnProjectInitialization) { - try { - await this.installDeps(true); - } catch (error: any) { - // this is best effort - console.warn("An error occured while installing dependencies", error); - } - this.depsInitialized = true; + async runModel(runModelParams: RunModelParams) { + if (!this.validateIntegrationPrerequisites()) { + return undefined; } - await this.dbtProjectIntegration.rebuildManifest(); - this._onRebuildManifestStatusChange.fire({ - project: this, - inProgress: false, - }); - this.terminal.debug( - "DBTProject", - `Finished rebuilding the manifest for "${this.getProjectName()}" at ${ - this.projectRoot - }`, - ); - } - async runModel(runModelParams: RunModelParams) { const runModelCommand = this.dbtCommandFactory.createRunModelCommand(runModelParams); try { - const result = await this.dbtProjectIntegration.runModel(runModelCommand); + const command = + await this.getCurrentProjectIntegration().runModel(runModelCommand); this.telemetry.sendTelemetryEvent("runModel"); - return result; + if (command) { + 
this.addCommandToQueue("all", command); + } } catch (error) { this.handleNoCredentialsError(error); } } async unsafeRunModelImmediately(runModelParams: RunModelParams) { - const runModelCommand = - this.dbtCommandFactory.createRunModelCommand(runModelParams); - runModelCommand.showProgress = false; - runModelCommand.logToTerminal = false; this.telemetry.sendTelemetryEvent("runModel"); - return this.dbtProjectIntegration.executeCommandImmediately( - runModelCommand, - ); + return this.dbtProjectIntegration.unsafeRunModelImmediately(runModelParams); } async buildModel(runModelParams: RunModelParams) { + if (!this.validateIntegrationPrerequisites()) { + return undefined; + } + const buildModelCommand = this.dbtCommandFactory.createBuildModelCommand(runModelParams); try { - const result = - await this.dbtProjectIntegration.buildModel(buildModelCommand); + const command = + await this.getCurrentProjectIntegration().buildModel(buildModelCommand); this.telemetry.sendTelemetryEvent("buildModel"); - return result; + if (command) { + this.addCommandToQueue("all", command); + } } catch (error) { this.handleNoCredentialsError(error); } } async unsafeBuildModelImmediately(runModelParams: RunModelParams) { - const buildModelCommand = - this.dbtCommandFactory.createBuildModelCommand(runModelParams); - buildModelCommand.showProgress = false; - buildModelCommand.logToTerminal = false; this.telemetry.sendTelemetryEvent("buildModel"); - return this.dbtProjectIntegration.executeCommandImmediately( - buildModelCommand, + return this.dbtProjectIntegration.unsafeBuildModelImmediately( + runModelParams, ); } async buildProject() { + if (!this.validateIntegrationPrerequisites()) { + return; + } + const buildProjectCommand = this.dbtCommandFactory.createBuildProjectCommand(); try { - const result = - await this.dbtProjectIntegration.buildProject(buildProjectCommand); + const command = + await this.getCurrentProjectIntegration().buildProject( + buildProjectCommand, + ); this.telemetry.sendTelemetryEvent("buildProject"); - return result; + if (command) { + this.addCommandToQueue("all", command); + } } catch (error) { this.handleNoCredentialsError(error); } } async unsafeBuildProjectImmediately() { - const buildProjectCommand = - this.dbtCommandFactory.createBuildProjectCommand(); - buildProjectCommand.showProgress = false; - buildProjectCommand.logToTerminal = false; this.telemetry.sendTelemetryEvent("buildProject"); - return this.dbtProjectIntegration.executeCommandImmediately( - buildProjectCommand, - ); + return this.dbtProjectIntegration.unsafeBuildProjectImmediately(); } async runTest(testName: string) { + if (!this.validateIntegrationPrerequisites()) { + return undefined; + } + const testModelCommand = this.dbtCommandFactory.createTestModelCommand(testName); try { - const result = await this.dbtProjectIntegration.runTest(testModelCommand); + const command = + await this.getCurrentProjectIntegration().runTest(testModelCommand); this.telemetry.sendTelemetryEvent("runTest"); - return result; + if (command) { + this.addCommandToQueue("all", command); + } } catch (error) { this.handleNoCredentialsError(error); } } async unsafeRunTestImmediately(testName: string) { - const testModelCommand = - this.dbtCommandFactory.createTestModelCommand(testName); - testModelCommand.showProgress = false; - testModelCommand.logToTerminal = false; this.telemetry.sendTelemetryEvent("runTest"); - return this.dbtProjectIntegration.executeCommandImmediately( - testModelCommand, - ); + return 
this.dbtProjectIntegration.unsafeRunTestImmediately(testName); } async runModelTest(modelName: string) { + if (!this.validateIntegrationPrerequisites()) { + return undefined; + } + const testModelCommand = this.dbtCommandFactory.createTestModelCommand(modelName); try { - const result = - await this.dbtProjectIntegration.runModelTest(testModelCommand); + const command = + await this.getCurrentProjectIntegration().runModelTest( + testModelCommand, + ); this.telemetry.sendTelemetryEvent("runModelTest"); - return result; + if (command) { + this.addCommandToQueue("all", command); + } } catch (error) { this.handleNoCredentialsError(error); } } async unsafeRunModelTestImmediately(modelName: string) { - const testModelCommand = - this.dbtCommandFactory.createTestModelCommand(modelName); - testModelCommand.showProgress = false; - testModelCommand.logToTerminal = false; this.telemetry.sendTelemetryEvent("runModelTest"); - return this.dbtProjectIntegration.executeCommandImmediately( - testModelCommand, - ); + return this.dbtProjectIntegration.unsafeRunModelTestImmediately(modelName); } private handleNoCredentialsError(error: unknown) { if (error instanceof NoCredentialsError) { - this.altimate.handlePreviewFeatures(); + this.altimateAuthService.handlePreviewFeatures(); return; } window.showErrorMessage((error as Error).message); } - compileModel(runModelParams: RunModelParams) { + private validateIntegrationPrerequisites(): boolean { + // Validate different prerequisites based on integration type + const dbtIntegrationMode = workspace + .getConfiguration("dbt") + .get("dbtIntegration", "core"); + + switch (dbtIntegrationMode) { + case "cloud": + case "fusion": + // For cloud/fusion integrations, validate authentication + try { + this.validationProvider.validateCredentialsSilently(); + return true; + } catch (e) { + window.showErrorMessage((e as Error).message); + return false; + } + case "core": + case "corecommand": + default: + // For core integrations, check if we have a proper dbt installation + // We'll validate through the integration's diagnostic system + const diagnostics = + this.getCurrentProjectIntegration().getDiagnostics(); + const hasErrors = [ + ...diagnostics.pythonBridgeDiagnostics, + ...diagnostics.rebuildManifestDiagnostics, + ].some((diagnostic) => diagnostic.severity === "error"); + + if (hasErrors) { + window.showErrorMessage( + "dbt installation or Python environment is not properly configured", + ); + return false; + } + return true; + } + } + + private requiresAuthentication(): boolean { + const dbtIntegrationMode = workspace + .getConfiguration("dbt") + .get("dbtIntegration", "core"); + return dbtIntegrationMode === "cloud"; + } + + throwIfNotAuthenticated() { + if (this.requiresAuthentication()) { + this.validationProvider.throwIfNotAuthenticated(); + } + } + + async compileModel(runModelParams: RunModelParams) { + if (!this.validateIntegrationPrerequisites()) { + return; + } + const compileModelCommand = this.dbtCommandFactory.createCompileModelCommand(runModelParams); - this.dbtProjectIntegration.compileModel(compileModelCommand); + const command = + await this.getCurrentProjectIntegration().compileModel( + compileModelCommand, + ); this.telemetry.sendTelemetryEvent("compileModel"); + if (command) { + this.addCommandToQueue("all", command); + } } - async generateDocsImmediately(args?: string[]) { - const docsGenerateCommand = - this.dbtCommandFactory.createDocsGenerateCommand(); - args?.forEach((arg) => docsGenerateCommand.addArgument(arg)); - docsGenerateCommand.focus = 
false; - docsGenerateCommand.logToTerminal = false; - const result = - await this.dbtProjectIntegration.executeCommandImmediately( - docsGenerateCommand, - ); - if (result?.stderr) { - throw new Error(result.stderr); - } + async unsafeCompileModelImmediately(runModelParams: RunModelParams) { + this.telemetry.sendTelemetryEvent("compileModel"); + return this.dbtProjectIntegration.unsafeCompileModelImmediately( + runModelParams, + ); + } + + async unsafeGenerateDocsImmediately(args?: string[]) { + return this.dbtProjectIntegration.unsafeGenerateDocsImmediately(args); } - generateDocs() { + async generateDocs() { + if (!this.validateIntegrationPrerequisites()) { + return; + } + const docsGenerateCommand = this.dbtCommandFactory.createDocsGenerateCommand(); - this.dbtProjectIntegration.generateDocs(docsGenerateCommand); + const command = + await this.getCurrentProjectIntegration().generateDocs( + docsGenerateCommand, + ); this.telemetry.sendTelemetryEvent("generateDocs"); + if (command) { + this.addCommandToQueue("all", command); + } } clean() { - const cleanCommand = this.dbtCommandFactory.createCleanCommand(); + this.throwIfNotAuthenticated(); this.telemetry.sendTelemetryEvent("clean"); - return this.dbtProjectIntegration.clean(cleanCommand); + return this.dbtProjectIntegration.clean(); } - debug() { - const debugCommand = this.dbtCommandFactory.createDebugCommand(); + debug(focus: boolean = true) { this.telemetry.sendTelemetryEvent("debug"); - return this.dbtProjectIntegration.debug(debugCommand); + return this.dbtProjectIntegration.debug(focus); } async installDbtPackages(packages: string[]) { this.telemetry.sendTelemetryEvent("installDbtPackages"); - const installPackagesCommand = - this.dbtCommandFactory.createAddPackagesCommand(packages); - // Add packages first - await this.dbtProjectIntegration.deps(installPackagesCommand); - // Then install - return await this.dbtProjectIntegration.deps( - this.dbtCommandFactory.createInstallDepsCommand(), - ); + return this.dbtProjectIntegration.installDbtPackages(packages); } async installDeps(silent = false) { this.telemetry.sendTelemetryEvent("installDeps"); - const installDepsCommand = - this.dbtCommandFactory.createInstallDepsCommand(); - if (silent) { - installDepsCommand.focus = false; - } - return this.dbtProjectIntegration.deps(installDepsCommand); + return this.dbtProjectIntegration.installDeps(silent); } async compileNode(modelName: string): Promise { this.telemetry.sendTelemetryEvent("compileNode"); + this.throwDiagnosticsErrorIfAvailable(); try { return await this.dbtProjectIntegration.unsafeCompileNode(modelName); } catch (exc: any) { @@ -815,24 +943,32 @@ export class DBTProject implements Disposable { async unsafeCompileNode(modelName: string): Promise { this.telemetry.sendTelemetryEvent("unsafeCompileNode"); - return await this.dbtProjectIntegration.unsafeCompileNode(modelName); + this.throwDiagnosticsErrorIfAvailable(); + this.throwIfNotAuthenticated(); + return this.dbtProjectIntegration.unsafeCompileNode(modelName); } async validateSql(request: { sql: string; dialect: string; models: any[] }) { + this.throwDiagnosticsErrorIfAvailable(); + this.throwIfNotAuthenticated(); + const sqlValidationThread = this.executionInfrastructure.createPythonBridge( + this.projectRoot.fsPath, + ); + const { sql, dialect, models } = request; try { - const { sql, dialect, models } = request; - return this.dbtProjectIntegration.validateSql(sql, dialect, models); - } catch (exc) { - window.showErrorMessage( - extendErrorWithSupportLinks("Could not 
validate sql." + exc), + return await validateSQLUsingSqlGlot( + sqlValidationThread, + sql, + dialect, + models, ); - this.telemetry.sendTelemetryError("validateSQLError", { - error: exc, - }); + } finally { + await this.executionInfrastructure.closePythonBridge(sqlValidationThread); } } async validateSQLDryRun(query: string) { + this.throwIfNotAuthenticated(); try { return this.dbtProjectIntegration.validateSQLDryRun(query); } catch (exc) { @@ -849,7 +985,7 @@ export class DBTProject implements Disposable { getDBTVersion(): number[] | undefined { // TODO: do this when config or python env changes and cache value try { - return this.dbtProjectIntegration.getVersion(); + return this.getCurrentProjectIntegration().getVersion(); } catch (exc) { window.showErrorMessage( extendErrorWithSupportLinks("Could not get dbt version." + exc), @@ -915,6 +1051,7 @@ export class DBTProject implements Disposable { query: string, originalModelName: string | undefined = undefined, ) { + this.throwIfNotAuthenticated(); return this.dbtProjectIntegration.unsafeCompileQuery( query, originalModelName, @@ -922,18 +1059,20 @@ export class DBTProject implements Disposable { } async getColumnsOfModel(modelName: string) { + this.throwIfNotAuthenticated(); const result = await this.dbtProjectIntegration.getColumnsOfModel(modelName); - await this.dbtProjectIntegration.cleanupConnections(); + await this.getCurrentProjectIntegration().cleanupConnections(); return result; } async getColumnsOfSource(sourceName: string, tableName: string) { + this.throwIfNotAuthenticated(); const result = await this.dbtProjectIntegration.getColumnsOfSource( sourceName, tableName, ); - await this.dbtProjectIntegration.cleanupConnections(); + await this.getCurrentProjectIntegration().cleanupConnections(); return result; } @@ -944,25 +1083,20 @@ export class DBTProject implements Disposable { ); try { + this.throwIfNotAuthenticated(); this.terminal.debug( "getColumnValues", "finding distinct values for column", true, { model, column }, ); - const query = `select ${column} from {{ ref('${model}')}} group by ${column}`; - const queryExecution = await this.dbtProjectIntegration.executeSQL( - query, - 100, // setting this 100 as executeSql needs a limit and distinct values will be usually less in number - model, - ); - const result = await queryExecution.executeQuery(); + const result = this.dbtProjectIntegration.getColumnValues(model, column); this.telemetry.endTelemetryEvent( TelemetryEvents["DocumentationEditor/GetDistinctColumnValues"], undefined, { column, model }, ); - return result.table.rows.flat(); + return (result as any).flat(); } catch (error) { this.telemetry.endTelemetryEvent( TelemetryEvents["DocumentationEditor/GetDistinctColumnValues"], @@ -971,30 +1105,29 @@ export class DBTProject implements Disposable { ); throw error; } finally { - await this.dbtProjectIntegration.cleanupConnections(); + await this.getCurrentProjectIntegration().cleanupConnections(); } } - async getBulkSchemaFromDB( - req: DBTNode[], - cancellationToken: CancellationToken, - ) { + async getBulkSchemaFromDB(req: DBTNode[], signal: AbortSignal) { + this.throwIfNotAuthenticated(); try { - const result = await this.dbtProjectIntegration.getBulkSchemaFromDB( - req, - cancellationToken, - ); - await this.dbtProjectIntegration.cleanupConnections(); + const result = + await this.getCurrentProjectIntegration().getBulkSchemaFromDB( + req, + signal, + ); + await this.getCurrentProjectIntegration().cleanupConnections(); return result; } finally { - await 
this.dbtProjectIntegration.cleanupConnections(); + await this.getCurrentProjectIntegration().cleanupConnections(); } } async validateWhetherSqlHasColumns(sql: string) { const dialect = this.getAdapterType(); try { - return await this.dbtProjectIntegration.validateWhetherSqlHasColumns( + return await this.getCurrentProjectIntegration().validateWhetherSqlHasColumns( sql, dialect, ); @@ -1007,13 +1140,14 @@ export class DBTProject implements Disposable { ); return false; } finally { - await this.dbtProjectIntegration.cleanupConnections(); + await this.getCurrentProjectIntegration().cleanupConnections(); } } async getCatalog(): Promise { + this.throwIfNotAuthenticated(); try { - const result = await this.dbtProjectIntegration.getCatalog(); + const result = await this.getCurrentProjectIntegration().getCatalog(); return result; } catch (exc: any) { if (exc instanceof PythonException) { @@ -1038,7 +1172,7 @@ export class DBTProject implements Disposable { ); return []; } finally { - await this.dbtProjectIntegration.cleanupConnections(); + await this.getCurrentProjectIntegration().cleanupConnections(); } } @@ -1195,98 +1329,99 @@ export class DBTProject implements Disposable { ); } - async executeSQLWithLimit( + async executeSQLOnQueryPanel(query: string, modelName: string) { + const limit = workspace + .getConfiguration("dbt") + .get("queryLimit", 500); + return this.executeSQLWithLimitOnQueryPanel(query, modelName, limit); + } + + async executeSQLWithLimitOnQueryPanel( query: string, modelName: string, limit: number, - returnImmediately?: boolean, - returnRawResults?: boolean, ) { - // if user added a semicolon at the end, let,s remove it. - query = query.replace(/;\s*$/, ""); - - // Check if query already contains a LIMIT clause and extract it - const limitRegex = /\bLIMIT\s+(\d+)\s*$/i; - const limitMatch = query.match(limitRegex); - - if (limitMatch) { - // Override the limit with the one from the query - const queryLimit = parseInt(limitMatch[1], 10); - if (queryLimit > 0) { - limit = queryLimit; - } - // Remove the LIMIT clause from the query as we'll add it back later - query = query.replace(limitRegex, "").trim(); - } - if (limit <= 0) { window.showErrorMessage("Please enter a positive number for query limit"); return; } - this.telemetry.sendTelemetryEvent("executeSQL", { - adapter: this.getAdapterType(), - limit: limit.toString(), - }); - this.terminal.debug("executeSQL", query, { + this.terminal.info("executeSQL", "Executed query: " + query, true, { adapter: this.getAdapterType(), limit: limit.toString(), }); - - if (returnImmediately) { - const execution = await this.dbtProjectIntegration.executeSQL( - query, - limit, - modelName, - ); - const result = await execution.executeQuery(); - if (returnRawResults) { - return result; - } - const rows: JsonObj[] = []; - // Convert compressed array format to dict[] - for (let i = 0; i < result.table.rows.length; i++) { - result.table.rows[i].forEach((value: any, j: any) => { - rows[i] = { ...rows[i], [result.table.column_names[j]]: value }; - }); - } - const data = { - columnNames: result.table.column_names, - columnTypes: result.table.column_types, - data: rows, - raw_sql: query, - compiled_sql: result.compiled_sql, - }; - - return data; - } this.eventEmitterService.fire({ command: "executeQuery", payload: { query, - fn: this.dbtProjectIntegration.executeSQL(query, limit, modelName), + fn: this.dbtProjectIntegration.executeSQLWithLimit( + query, + modelName, + limit, + ), projectName: this.getProjectName(), }, }); } - executeSQL( + async 
immediatelyExecuteSQLWithLimit( query: string, modelName: string, - returnImmediately?: boolean, - returnRawResults?: boolean, - ) { - const limit = workspace - .getConfiguration("dbt") - .get("queryLimit", 500); - return this.executeSQLWithLimit( + limit: number, + ): Promise { + this.throwDiagnosticsErrorIfAvailable(); + this.throwIfNotAuthenticated(); + this.terminal.info("executeSQL", "Executed query: " + query, true, { + adapter: this.getAdapterType(), + limit: limit.toString(), + }); + return this.dbtProjectIntegration.immediatelyExecuteSQLWithLimit( query, modelName, limit, - returnImmediately, - returnRawResults, ); } + async executeSQLWithLimit(query: string, modelName: string, limit: number) { + this.throwDiagnosticsErrorIfAvailable(); + this.throwIfNotAuthenticated(); + this.terminal.info("executeSQL", "Executed query: " + query, true, { + adapter: this.getAdapterType(), + limit: limit.toString(), + }); + return this.dbtProjectIntegration.executeSQLWithLimit( + query, + modelName, + limit, + ); + } + + async immediatelyExecuteSQL( + query: string, + modelName: string, + ): Promise { + this.throwDiagnosticsErrorIfAvailable(); + this.throwIfNotAuthenticated(); + const limit = workspace + .getConfiguration("dbt") + .get("queryLimit", 500); + this.terminal.info("executeSQL", "Executed query: " + query, true, { + adapter: this.getAdapterType(), + limit: limit.toString(), + }); + return this.dbtProjectIntegration.immediatelyExecuteSQL(query, modelName); + } + + executeSQL(query: string, modelName: string): Promise { + const limit = workspace + .getConfiguration("dbt") + .get("queryLimit", 500); + this.terminal.info("executeSQL", "Executed query: " + query, true, { + adapter: this.getAdapterType(), + limit: limit.toString(), + }); + return this.dbtProjectIntegration.executeSQL(query, modelName); + } + async dispose() { while (this.disposables.length) { const x = this.disposables.pop(); @@ -1296,23 +1431,6 @@ export class DBTProject implements Disposable { } } - static readAndParseProjectConfig(projectRoot: Uri) { - const dbtProjectConfigLocation = path.join( - projectRoot.fsPath, - DBTProject.DBT_PROJECT_FILE, - ); - const dbtProjectYamlFile = readFileSync(dbtProjectConfigLocation, "utf8"); - return parse(dbtProjectYamlFile, { - strict: false, - uniqueKeys: false, - maxAliasCount: -1, - }); - } - - static hashProjectRoot(projectRoot: string) { - return crypto.createHash("md5").update(projectRoot).digest("hex"); - } - private async findModelInTargetfolder(modelPath: Uri, type: string) { const targetPath = this.getTargetPath(); if (!targetPath) { @@ -1335,209 +1453,24 @@ export class DBTProject implements Disposable { } } - static isResourceNode(resource_type: string): boolean { - return ( - resource_type === DBTProject.RESOURCE_TYPE_MODEL || - resource_type === DBTProject.RESOURCE_TYPE_SEED || - resource_type === DBTProject.RESOURCE_TYPE_ANALYSIS || - resource_type === DBTProject.RESOURCE_TYPE_SNAPSHOT - ); + static isResourceNode(resourceType: string): boolean { + return isResourceNode(resourceType); } - static isResourceHasDbColumns(resource_type: string): boolean { - return ( - resource_type === DBTProject.RESOURCE_TYPE_MODEL || - resource_type === DBTProject.RESOURCE_TYPE_SEED || - resource_type === DBTProject.RESOURCE_TYPE_SNAPSHOT - ); + + static isResourceHasDbColumns(resourceType: string): boolean { + return isResourceHasDbColumns(resourceType); } getNonEphemeralParents(keys: string[]): string[] { - if (!this._manifestCacheEvent) { - throw Error( - "No manifest has been 
generated. Maybe dbt project has not been parsed yet?", - ); - } - const { nodeMetaMap, graphMetaMap } = this._manifestCacheEvent; - const { parents } = graphMetaMap; - const parentSet = new Set(); - const queue = keys; - const visited: Record = {}; - while (queue.length > 0) { - const curr = queue.shift()!; - if (visited[curr]) { - continue; - } - visited[curr] = true; - const parent = parents.get(curr); - if (!parent) { - continue; - } - for (const n of parent.nodes) { - const splits = n.key.split("."); - const resource_type = splits[0]; - if (resource_type !== DBTProject.RESOURCE_TYPE_MODEL) { - parentSet.add(n.key); - continue; - } - if ( - nodeMetaMap.lookupByUniqueId(n.key)?.config.materialized === - "ephemeral" - ) { - queue.push(n.key); - } else { - parentSet.add(n.key); - } - } - } - return Array.from(parentSet); + return this.dbtProjectIntegration.getNonEphemeralParents(keys); } getChildrenModels({ table }: { table: string }): Table[] { - return this.getConnectedTables("children", table); + return this.dbtProjectIntegration.getChildrenModels({ table }); } getParentModels({ table }: { table: string }): Table[] { - return this.getConnectedTables("parents", table); - } - - private getConnectedTables(key: keyof GraphMetaMap, table: string): Table[] { - const event = this._manifestCacheEvent; - if (!event) { - throw Error( - "No manifest has been generated. Maybe dbt project has not been parsed yet?", - ); - } - const { graphMetaMap, nodeMetaMap } = event; - const node = nodeMetaMap.lookupByBaseName(table); - if (!node) { - throw Error("nodeMetaMap has no entries for " + table); - } - const dependencyNodes = graphMetaMap[key]; - const dependencyNode = dependencyNodes.get(node.uniqueId); - if (!dependencyNode) { - throw Error("graphMetaMap[" + key + "] has no entries for " + table); - } - const tables: Map = new Map(); - dependencyNode.nodes.forEach(({ url, key }) => { - const _node = this.createTable(event, url, key); - if (!_node) { - return; - } - if (!tables.has(_node.table)) { - tables.set(_node.table, _node); - } - }); - return Array.from(tables.values()).sort((a, b) => - a.table.localeCompare(b.table), - ); - } - - private createTable( - event: ManifestCacheProjectAddedEvent, - tableUrl: string | undefined, - key: string, - ): Table | undefined { - const splits = key.split("."); - const nodeType = splits[0]; - const { graphMetaMap, testMetaMap } = event; - const upstreamCount = this.getConnectedNodeCount( - graphMetaMap["children"], - key, - ); - const downstreamCount = this.getConnectedNodeCount( - graphMetaMap["parents"], - key, - ); - if (nodeType === DBTProject.RESOURCE_TYPE_SOURCE) { - const { sourceMetaMap } = event; - const schema = splits[2]; - const table = splits[3]; - const _node = sourceMetaMap.get(schema); - if (!_node) { - return; - } - const _table = _node.tables.find((t) => t.name === table); - if (!_table) { - return; - } - return { - table: key, - label: table, - url: tableUrl, - upstreamCount, - downstreamCount, - nodeType, - isExternalProject: _node.is_external_project, - tests: (graphMetaMap["tests"].get(key)?.nodes || []).map((n) => { - const testKey = n.label.split(".")[0]; - return { ...testMetaMap.get(testKey), key: testKey }; - }), - columns: _table.columns, - description: _table?.description, - packageName: _node.package_name, - }; - } - if (nodeType === DBTProject.RESOURCE_TYPE_METRIC) { - return { - table: key, - label: splits[2], - url: tableUrl, - upstreamCount, - downstreamCount, - nodeType, - materialization: undefined, - tests: [], - columns: 
{}, - isExternalProject: false, - }; - } - const { nodeMetaMap } = event; - - const table = splits[2]; - if (nodeType === DBTProject.RESOURCE_TYPE_EXPOSURE) { - return { - table: key, - label: table, - url: tableUrl, - upstreamCount, - downstreamCount, - nodeType, - materialization: undefined, - tests: [], - columns: {}, - isExternalProject: false, - }; - } - - const node = nodeMetaMap.lookupByUniqueId(key); - if (!node) { - return; - } - - const materialization = node.config.materialized; - return { - table: key, - label: node.alias, - url: tableUrl, - upstreamCount, - downstreamCount, - isExternalProject: node.is_external_project, - nodeType, - materialization, - description: node.description, - columns: node.columns, - patchPath: node.patch_path, - tests: (graphMetaMap["tests"].get(key)?.nodes || []).map((n) => { - const testKey = n.label.split(".")[0]; - return { ...testMetaMap.get(testKey), key: testKey }; - }), - packageName: node.package_name, - meta: node.meta, - }; - } - - private getConnectedNodeCount(g: NodeGraphMap, key: string) { - return g.get(key)?.nodes.length || 0; + return this.dbtProjectIntegration.getParentModels({ table }); } mergeColumnsFromDB( @@ -1583,7 +1516,8 @@ export class DBTProject implements Disposable { } public findPackageVersion(packageName: string) { - const version = this.dbtProjectIntegration.findPackageVersion(packageName); + const version = + this.getCurrentProjectIntegration().findPackageVersion(packageName); this.terminal.debug( "dbtProject:findPackageVersion", `found ${packageName} version: ${version}`, @@ -1591,32 +1525,31 @@ export class DBTProject implements Disposable { return version; } - async getBulkCompiledSql( - event: ManifestCacheProjectAddedEvent, - models: string[], - ) { + async getBulkCompiledSql(models: string[]) { if (models.length === 0) { return {}; } - const { nodeMetaMap } = event; - return this.dbtProjectIntegration.getBulkCompiledSQL( + if (!this._manifestCacheEvent) { + throw new Error("The dbt manifest is not available"); + } + const { nodeMetaMap } = this._manifestCacheEvent; + return this.getCurrentProjectIntegration().getBulkCompiledSQL( models .map((m) => nodeMetaMap.lookupByUniqueId(m)) .filter(Boolean) as NodeMetaData[], ); } - async getNodesWithDBColumns( - event: ManifestCacheProjectAddedEvent, - modelsToFetch: string[], - cancellationToken: CancellationToken, - ) { + async getNodesWithDBColumns(modelsToFetch: string[], signal: AbortSignal) { const mappedNode: Record = {}; const relationsWithoutColumns: string[] = []; if (modelsToFetch.length === 0) { return { mappedNode, relationsWithoutColumns, mappedCompiledSql: {} }; } - const { nodeMetaMap, sourceMetaMap } = event; + if (!this._manifestCacheEvent) { + throw new Error("The dbt manifest is not available"); + } + const { nodeMetaMap, sourceMetaMap } = this._manifestCacheEvent; const bulkSchemaRequest: DBTNode[] = []; for (const key of modelsToFetch) { @@ -1627,7 +1560,7 @@ export class DBTProject implements Disposable { } const splits = key.split("."); const resource_type = splits[0]; - if (resource_type === DBTProject.RESOURCE_TYPE_SOURCE) { + if (resource_type === RESOURCE_TYPE_SOURCE) { const source = sourceMetaMap.get(splits[2]); const tableName = splits[3]; if (!source) { @@ -1670,20 +1603,19 @@ export class DBTProject implements Disposable { } const dbSchemaRequest = bulkSchemaRequest.filter( - (r) => r.resource_type !== DBTProject.RESOURCE_TYPE_MODEL, + (r) => r.resource_type !== RESOURCE_TYPE_MODEL, ); const sqlglotSchemaRequest = bulkSchemaRequest.filter( 
- (r) => r.resource_type === DBTProject.RESOURCE_TYPE_MODEL, + (r) => r.resource_type === RESOURCE_TYPE_MODEL, ); let startTime = Date.now(); const sqlglotSchemaResponse = await this.getBulkCompiledSql( - event, sqlglotSchemaRequest.map((r) => r.unique_id), ); const compiledSqlTime = Date.now() - startTime; - if (cancellationToken.isCancellationRequested) { + if (signal.aborted) { return { mappedNode, relationsWithoutColumns, @@ -1703,10 +1635,11 @@ export class DBTProject implements Disposable { } try { - const columns = await this.dbtProjectIntegration.fetchSqlglotSchema( - sqlglotSchemaResponse[r.unique_id], - dialect, - ); + const columns = + await this.getCurrentProjectIntegration().fetchSqlglotSchema( + sqlglotSchemaResponse[r.unique_id], + dialect, + ); sqlglotSchemas[r.unique_id] = columns.map((c) => ({ column: c, dtype: "string", @@ -1723,7 +1656,7 @@ export class DBTProject implements Disposable { } const sqlglotSchemaTime = Date.now() - startTime; - if (cancellationToken.isCancellationRequested) { + if (signal.aborted) { return { mappedNode, relationsWithoutColumns, @@ -1733,9 +1666,9 @@ export class DBTProject implements Disposable { startTime = Date.now(); const dbSchemaResponse = - await this.dbtProjectIntegration.getBulkSchemaFromDB( + await this.getCurrentProjectIntegration().getBulkSchemaFromDB( dbSchemaRequest, - cancellationToken, + signal, ); const dbFetchTime = Date.now() - startTime; @@ -1782,10 +1715,135 @@ export class DBTProject implements Disposable { } async applyDeferConfig(): Promise { - await this.dbtProjectIntegration.applyDeferConfig(); + const deferConfig = this.retrieveDeferConfigFromSettings(); + await this.dbtProjectIntegration.applyDeferConfig(deferConfig); } throwDiagnosticsErrorIfAvailable() { - this.dbtProjectIntegration.throwDiagnosticsErrorIfAvailable(); + // Check integration diagnostics + const integrationDiagnostics = + this.getCurrentProjectIntegration().getDiagnostics(); + const allIntegrationDiagnostics = [ + ...integrationDiagnostics.pythonBridgeDiagnostics, + ...integrationDiagnostics.rebuildManifestDiagnostics, + ]; + + for (const diagnostic of allIntegrationDiagnostics) { + if (diagnostic.severity === "error") { + throw new Error(diagnostic.message); + } + } + + // Check VSCode diagnostic collections + const vscodeCollections: DiagnosticCollection[] = [ + this.projectHealth, + this.pythonBridgeDiagnostics, + this.rebuildManifestDiagnostics, + this.projectConfigDiagnostics, + ]; + + for (const diagnosticCollection of vscodeCollections) { + for (const [_, diagnostics] of diagnosticCollection) { + const error = diagnostics.find( + (diagnostic) => diagnostic.severity === DiagnosticSeverity.Error, + ); + if (error) { + throw new Error(error.message); + } + } + } + } + + private retrieveDeferConfigFromSettings(): DeferConfig | undefined { + const relativePath = getProjectRelativePath(this.projectRoot); + const currentConfig: Record = workspace + .getConfiguration("dbt") + .get("deferConfigPerProject", {}); + if (currentConfig[relativePath]) { + const config = currentConfig[relativePath]; + return new DeferConfig( + config.deferToProduction, + config.favorState, + config.manifestPathForDeferral, + config.manifestPathType, + config.dbtCoreIntegrationId, + ); + } + } + + getDeferConfig(): DeferConfig { + if (!this.dbtProjectIntegration) { + throw new Error("DBT Project Integration is not initialized."); + } + return this.dbtProjectIntegration.getDeferConfig(); + } + + private createQueue(queueName: string) { + this.queues.set(queueName, []); + } + 
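+  // Commands pushed onto a queue by addCommandToQueue() are drained one at a
+  // time by pickCommandToRun(), so dbt commands for this project run serially
+  // rather than concurrently. When showProgress is set, the command executes
+  // inside a VSCode progress notification, and cancelling the notification
+  // aborts the command via an AbortController signal.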
+ private addCommandToQueue(queueName: string, command: DBTCommand): void { + this.queues.get(queueName)!.push({ + command: async (signal) => { + await command.execute(signal); + }, + statusMessage: command.statusMessage, + focus: command.focus, + signal: command.signal, + showProgress: command.showProgress, + }); + this.pickCommandToRun(queueName); + } + + private async pickCommandToRun(queueName: string): Promise { + const queue = this.queues.get(queueName)!; + const running = this.queueStates.get(queueName); + if (!running && queue.length > 0) { + this.queueStates.set(queueName, true); + const { command, statusMessage, focus, showProgress } = queue.shift()!; + const commandExecution = async (signal?: AbortSignal) => { + try { + await command(signal); + } catch (error) { + if (error instanceof NoCredentialsError) { + this.altimateAuthService.handlePreviewFeatures(); + return; + } + window.showErrorMessage( + extendErrorWithSupportLinks( + `Could not run command '${statusMessage}': ` + error + ".", + ), + ); + this.telemetry.sendTelemetryError("queueRunCommandError", error, { + command: statusMessage, + }); + } + }; + + if (showProgress) { + await window.withProgress( + { + location: focus + ? ProgressLocation.Notification + : ProgressLocation.Window, + cancellable: true, + title: statusMessage, + }, + async (_, token) => { + const abortController = new AbortController(); + token.onCancellationRequested(() => abortController.abort()); + await commandExecution(abortController.signal); + }, + ); + } else { + await commandExecution(); + } + this.queueStates.set(queueName, false); + this.pickCommandToRun(queueName); + } + } + + private getCurrentProjectIntegration(): DBTProjectIntegration { + return this.dbtProjectIntegration.getCurrentProjectIntegration(); } } diff --git a/src/manifest/dbtProjectContainer.ts b/src/dbt_client/dbtProjectContainer.ts similarity index 97% rename from src/manifest/dbtProjectContainer.ts rename to src/dbt_client/dbtProjectContainer.ts index 848bb0d72..f9102ce5c 100644 --- a/src/manifest/dbtProjectContainer.ts +++ b/src/dbt_client/dbtProjectContainer.ts @@ -1,3 +1,9 @@ +import { + DataPilotHealtCheckParams, + DBTTerminal, + EnvironmentVariables, + RunModelType, +} from "@altimateai/dbt-integration"; import { inject } from "inversify"; import { basename } from "path"; import { @@ -10,19 +16,15 @@ import { workspace, WorkspaceFolder, } from "vscode"; +import { AltimateRequest } from "../altimate"; import { DBTClient } from "../dbt_client"; -import { EnvironmentVariables, RunModelType } from "../domain"; -import { provideSingleton } from "../utils"; +import { AltimateDatapilot } from "../dbt_client/datapilot"; import { DBTProject } from "./dbtProject"; import { DBTWorkspaceFolder } from "./dbtWorkspaceFolder"; import { ManifestCacheChangedEvent, RebuildManifestCombinedStatusChange, } from "./event/manifestCacheChangedEvent"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; -import { AltimateConfigProps } from "../webview_provider/insightsPanel"; -import { AltimateDatapilot } from "../dbt_client/datapilot"; -import { AltimateRequest } from "../altimate"; enum PromptAnswer { YES = "Yes", @@ -37,7 +39,6 @@ export interface ProjectRegisteredUnregisteredEvent { export interface DBTProjectsInitializationEvent {} -@provideSingleton(DBTProjectContainer) export class DBTProjectContainer implements Disposable { public onDBTInstallationVerification = this.dbtClient.onDBTInstallationVerification; @@ -72,6 +73,7 @@ export class DBTProjectContainer implements Disposable { 
pythonPath?: string, envVars?: EnvironmentVariables, ) => DBTWorkspaceFolder, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, private altimateDatapilot: AltimateDatapilot, private altimate: AltimateRequest, @@ -277,7 +279,7 @@ export class DBTProjectContainer implements Disposable { uri = selectedProject.uri; } } - this.findDBTProject(uri)?.executeSQL(query, modelName); + this.findDBTProject(uri)?.executeSQLOnQueryPanel(query, modelName); } runModel(modelPath: Uri, type?: RunModelType) { @@ -451,7 +453,7 @@ export class DBTProjectContainer implements Disposable { ); } - executeAltimateDatapilotHealthcheck(args: AltimateConfigProps) { + executeAltimateDatapilotHealthcheck(args: DataPilotHealtCheckParams) { const project = this.getProjects().find( (p) => p.projectRoot.fsPath.toString() === args.projectRoot, ); diff --git a/src/manifest/modules/dbtProjectLog.ts b/src/dbt_client/dbtProjectLog.ts similarity index 88% rename from src/manifest/modules/dbtProjectLog.ts rename to src/dbt_client/dbtProjectLog.ts index 76be5e0f7..ef40d582c 100755 --- a/src/manifest/modules/dbtProjectLog.ts +++ b/src/dbt_client/dbtProjectLog.ts @@ -9,17 +9,8 @@ import { window, workspace, } from "vscode"; -import { provideSingleton, setupWatcherHandler, stripANSI } from "../../utils"; -import { ProjectConfigChangedEvent } from "../event/projectConfigChangedEvent"; - -@provideSingleton(DBTProjectLogFactory) -export class DBTProjectLogFactory { - createDBTProjectLog( - onProjectConfigChanged: Event, - ) { - return new DBTProjectLog(onProjectConfigChanged); - } -} +import { setupWatcherHandler, stripANSI } from "../utils"; +import { ProjectConfigChangedEvent } from "./event/projectConfigChangedEvent"; export class DBTProjectLog implements Disposable { private outputChannel?: OutputChannel; diff --git a/src/manifest/dbtWorkspaceFolder.ts b/src/dbt_client/dbtWorkspaceFolder.ts similarity index 91% rename from src/manifest/dbtWorkspaceFolder.ts rename to src/dbt_client/dbtWorkspaceFolder.ts index 98be2a8e6..12e434de0 100755 --- a/src/manifest/dbtWorkspaceFolder.ts +++ b/src/dbt_client/dbtWorkspaceFolder.ts @@ -1,5 +1,5 @@ import { existsSync, statSync } from "fs"; -import { inject, postConstruct } from "inversify"; +import { inject } from "inversify"; import * as path from "path"; import { Diagnostic, @@ -7,24 +7,28 @@ import { EventEmitter, FileSystemWatcher, languages, + Range, RelativePattern, Uri, - Range, window, workspace, WorkspaceFolder, } from "vscode"; +import { YAMLError } from "yaml"; +import { TelemetryService } from "../telemetry"; import { DBTProject } from "./dbtProject"; +import { ProjectRegisteredUnregisteredEvent } from "./dbtProjectContainer"; import { ManifestCacheChangedEvent, RebuildManifestStatusChange, } from "./event/manifestCacheChangedEvent"; -import { TelemetryService } from "../telemetry"; -import { YAMLError } from "yaml"; -import { ProjectRegisteredUnregisteredEvent } from "./dbtProjectContainer"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; -import { DBTProjectDetection } from "src/dbt_client/dbtIntegration"; +import { + DBTProjectDetection, + DBTTerminal, + DBT_PROJECT_FILE, + readAndParseProjectConfig, +} from "@altimateai/dbt-integration"; export class DBTWorkspaceFolder implements Disposable { private watcher: FileSystemWatcher; @@ -36,7 +40,6 @@ export class DBTWorkspaceFolder implements Disposable { new EventEmitter(); readonly onRebuildManifestStatusChange = this._onRebuildManifestStatusChange.event; - private dbtProjectDetection: DBTProjectDetection | 
undefined; constructor( @inject("DBTProjectFactory") @@ -48,6 +51,7 @@ export class DBTWorkspaceFolder implements Disposable { @inject("Factory") private dbtProjectDetectionFactory: () => DBTProjectDetection, private telemetry: TelemetryService, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, public workspaceFolder: WorkspaceFolder, private _onManifestChanged: EventEmitter, @@ -122,10 +126,7 @@ export class DBTWorkspaceFolder implements Disposable { const dbtProjectFiles = await this.retryWithBackoff( () => workspace.findFiles( - new RelativePattern( - this.workspaceFolder, - `**/${DBTProject.DBT_PROJECT_FILE}`, - ), + new RelativePattern(this.workspaceFolder, `**/${DBT_PROJECT_FILE}`), new RelativePattern(this.workspaceFolder, excludePattern), ), 5, @@ -176,8 +177,14 @@ export class DBTWorkspaceFolder implements Disposable { const filteredProjects = await this.dbtProjectDetectionFactory().discoverProjects( - projectDirectories, + projectDirectories.map((uri) => uri.fsPath), + ); + + if (filteredProjects.length > 20) { + window.showWarningMessage( + `dbt Power User detected ${filteredProjects.length} projects in your workspace, this will negatively affect performance.`, ); + } this.dbtTerminal.info( "discoverProjects", @@ -187,8 +194,8 @@ export class DBTWorkspaceFolder implements Disposable { ); await Promise.all( - filteredProjects.map(async (uri) => { - await this.registerDBTProject(uri); + filteredProjects.map(async (projectPath) => { + await this.registerDBTProject(Uri.file(projectPath)); }), ); } @@ -228,7 +235,7 @@ export class DBTWorkspaceFolder implements Disposable { private async registerDBTProject(uri: Uri) { try { - const projectConfig = DBTProject.readAndParseProjectConfig(uri); + const projectConfig = readAndParseProjectConfig(uri.fsPath); const dbtProject = this.dbtProjectFactory( uri, projectConfig, @@ -259,7 +266,7 @@ export class DBTWorkspaceFolder implements Disposable { ); if (error instanceof YAMLError) { this.projectDiscoveryDiagnostics.set( - Uri.joinPath(uri, DBTProject.DBT_PROJECT_FILE), + Uri.joinPath(uri, DBT_PROJECT_FILE), [new Diagnostic(new Range(0, 0, 999, 999), error.message)], ); } @@ -288,10 +295,7 @@ export class DBTWorkspaceFolder implements Disposable { private createConfigWatcher(): FileSystemWatcher { const watcher = workspace.createFileSystemWatcher( - new RelativePattern( - this.workspaceFolder, - `**/${DBTProject.DBT_PROJECT_FILE}`, - ), + new RelativePattern(this.workspaceFolder, `**/${DBT_PROJECT_FILE}`), ); const dirName = (uri: Uri) => Uri.file(path.dirname(uri.fsPath)); diff --git a/src/manifest/event/manifestCacheChangedEvent.ts b/src/dbt_client/event/manifestCacheChangedEvent.ts similarity index 96% rename from src/manifest/event/manifestCacheChangedEvent.ts rename to src/dbt_client/event/manifestCacheChangedEvent.ts index 060664d15..6a570ba9f 100755 --- a/src/manifest/event/manifestCacheChangedEvent.ts +++ b/src/dbt_client/event/manifestCacheChangedEvent.ts @@ -1,4 +1,3 @@ -import { Uri } from "vscode"; import { DocMetaMap, ExposureMetaMap, @@ -8,7 +7,8 @@ import { NodeMetaMap, SourceMetaMap, TestMetaMap, -} from "../../domain"; +} from "@altimateai/dbt-integration"; +import { Uri } from "vscode"; import { DBTProject } from "../dbtProject"; export interface ManifestCacheProjectAddedEvent { diff --git a/src/manifest/event/projectConfigChangedEvent.ts b/src/dbt_client/event/projectConfigChangedEvent.ts similarity index 100% rename from src/manifest/event/projectConfigChangedEvent.ts rename to 
src/dbt_client/event/projectConfigChangedEvent.ts diff --git a/src/manifest/event/runResultsEvent.ts b/src/dbt_client/event/runResultsEvent.ts similarity index 71% rename from src/manifest/event/runResultsEvent.ts rename to src/dbt_client/event/runResultsEvent.ts index 9b6f437da..bd797202f 100644 --- a/src/manifest/event/runResultsEvent.ts +++ b/src/dbt_client/event/runResultsEvent.ts @@ -1,9 +1,8 @@ -import { Uri } from "vscode"; import { DBTProject } from "../dbtProject"; export class RunResultsEvent { constructor( public project: DBTProject, - public file: Uri, + public uniqueIds?: string[], ) {} } diff --git a/src/dbt_client/index.ts b/src/dbt_client/index.ts index b4327dcdb..004b5d80c 100644 --- a/src/dbt_client/index.ts +++ b/src/dbt_client/index.ts @@ -1,16 +1,14 @@ -import { commands, Disposable, EventEmitter, window, workspace } from "vscode"; -import { PythonEnvironment } from "../manifest/pythonEnvironment"; -import { provideSingleton } from "../utils"; -import { DBTInstallationVerificationEvent } from "./dbtVersionEvent"; +import { DBTDetection } from "@altimateai/dbt-integration"; import { existsSync } from "fs"; -import { DBTDetection } from "./dbtIntegration"; import { inject } from "inversify"; +import { commands, Disposable, EventEmitter, window, workspace } from "vscode"; +import { DBTInstallationVerificationEvent } from "./dbtVersionEvent"; +import { PythonEnvironment } from "./pythonEnvironment"; enum PythonInterpreterPromptAnswer { SELECT = "Select Python interpreter", } -@provideSingleton(DBTClient) export class DBTClient implements Disposable { private _onDBTInstallationVerificationEvent = new EventEmitter(); @@ -29,6 +27,7 @@ export class DBTClient implements Disposable { ]; private shownError = false; constructor( + @inject(PythonEnvironment) private pythonEnvironment: PythonEnvironment, @inject("Factory") private dbtDetectionFactory: () => DBTDetection, diff --git a/src/manifest/pythonEnvironment.ts b/src/dbt_client/pythonEnvironment.ts old mode 100755 new mode 100644 similarity index 94% rename from src/manifest/pythonEnvironment.ts rename to src/dbt_client/pythonEnvironment.ts index e1af8c6b7..d1c2aa50b --- a/src/manifest/pythonEnvironment.ts +++ b/src/dbt_client/pythonEnvironment.ts @@ -1,9 +1,6 @@ +import { DBTTerminal, EnvironmentVariables } from "@altimateai/dbt-integration"; +import { inject } from "inversify"; import { Disposable, Event, extensions, Uri, workspace } from "vscode"; -import { EnvironmentVariables } from "../domain"; -import { provideSingleton } from "../utils"; -import { TelemetryService } from "../telemetry"; -import { CommandProcessExecutionFactory } from "../commandProcessExecution"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; type EnvFrom = "process" | "integrated" | "dotenv"; interface PythonExecutionDetails { @@ -12,20 +9,19 @@ interface PythonExecutionDetails { getEnvVars: () => EnvironmentVariables; } -@provideSingleton(PythonEnvironment) -export class PythonEnvironment implements Disposable { +export class PythonEnvironment { private executionDetails?: PythonExecutionDetails; private disposables: Disposable[] = []; private environmentVariableSource: Record = {}; public allPythonPaths: { path: string; pathType: string }[] = []; public isPython3: boolean = true; + constructor( - private telemetry: TelemetryService, - private commandProcessExecutionFactory: CommandProcessExecutionFactory, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, ) {} - dispose() { + async dispose(): Promise { while 
(this.disposables.length) { const x = this.disposables.pop(); if (x) { diff --git a/src/dbt_client/runtimePythonEnvironmentProvider.ts b/src/dbt_client/runtimePythonEnvironmentProvider.ts new file mode 100644 index 000000000..25a962661 --- /dev/null +++ b/src/dbt_client/runtimePythonEnvironmentProvider.ts @@ -0,0 +1,66 @@ +import { + PythonEnvironmentProvider, + RuntimePythonEnvironment, +} from "@altimateai/dbt-integration"; +import { inject, injectable } from "inversify"; +import { PythonEnvironment } from "./pythonEnvironment"; + +@injectable() +export class VSCodeRuntimePythonEnvironmentProvider + implements PythonEnvironmentProvider +{ + private callbacks: ((environment: RuntimePythonEnvironment) => void)[] = []; + + constructor( + @inject(PythonEnvironment) + private vscodeEnvironment: PythonEnvironment, + ) { + // Set up python environment change handling + // Initialize environment and listen for changes + this.vscodeEnvironment.initialize().then(() => { + this.vscodeEnvironment.onPythonEnvironmentChanged(() => { + const currentEnvironment = this.getCurrentEnvironment(); + this.callbacks.forEach((callback) => callback(currentEnvironment)); + }); + }); + } + + getCurrentEnvironment(): RuntimePythonEnvironment { + return { + pythonPath: this.vscodeEnvironment.pythonPath, + environmentVariables: this.vscodeEnvironment.environmentVariables, + }; + } + + onEnvironmentChanged( + callback: (environment: RuntimePythonEnvironment) => void, + ): () => void { + this.callbacks.push(callback); + + // Return cleanup function + return () => { + const index = this.callbacks.indexOf(callback); + if (index > -1) { + this.callbacks.splice(index, 1); + } + }; + } +} + +@injectable() +export class StaticRuntimePythonEnvironment + implements RuntimePythonEnvironment +{ + constructor( + @inject(PythonEnvironment) + private vscodeEnvironment: PythonEnvironment, + ) {} + + get pythonPath(): string { + return this.vscodeEnvironment.pythonPath; + } + + get environmentVariables() { + return this.vscodeEnvironment.environmentVariables; + } +} diff --git a/src/dbt_client/vscodeConfiguration.ts b/src/dbt_client/vscodeConfiguration.ts new file mode 100644 index 000000000..df3f4e31a --- /dev/null +++ b/src/dbt_client/vscodeConfiguration.ts @@ -0,0 +1,126 @@ +import { + DBTConfiguration, + DEFAULT_CONFIGURATION_VALUES, +} from "@altimateai/dbt-integration"; +import { injectable } from "inversify"; +import { workspace } from "vscode"; +import { getFirstWorkspacePath } from "../utils"; + +@injectable() +export class VSCodeDBTConfiguration implements DBTConfiguration { + getDbtCustomRunnerImport(): string { + return workspace + .getConfiguration("dbt") + .get( + "dbtCustomRunnerImport", + DEFAULT_CONFIGURATION_VALUES.dbtCustomRunnerImport, + ); + } + + getDbtIntegration(): string { + return workspace + .getConfiguration("dbt") + .get( + "dbtIntegration", + DEFAULT_CONFIGURATION_VALUES.dbtIntegration, + ); + } + + getRunModelCommandAdditionalParams(): string[] { + return workspace + .getConfiguration("dbt") + .get< + string[] + >("runModelCommandAdditionalParams", DEFAULT_CONFIGURATION_VALUES.runModelCommandAdditionalParams); + } + + getBuildModelCommandAdditionalParams(): string[] { + return workspace + .getConfiguration("dbt") + .get< + string[] + >("buildModelCommandAdditionalParams", DEFAULT_CONFIGURATION_VALUES.buildModelCommandAdditionalParams); + } + + getTestModelCommandAdditionalParams(): string[] { + return workspace + .getConfiguration("dbt") + .get< + string[] + >("testModelCommandAdditionalParams", 
DEFAULT_CONFIGURATION_VALUES.testModelCommandAdditionalParams); + } + + getQueryTemplate(): string { + return workspace + .getConfiguration("dbt") + .get("queryTemplate", DEFAULT_CONFIGURATION_VALUES.queryTemplate); + } + + getQueryLimit(): number { + return workspace + .getConfiguration("dbt") + .get("queryLimit", DEFAULT_CONFIGURATION_VALUES.queryLimit); + } + + getEnableNotebooks(): boolean { + return workspace + .getConfiguration("dbt") + .get( + "enableNotebooks", + DEFAULT_CONFIGURATION_VALUES.enableNotebooks, + ); + } + + getDisableQueryHistory(): boolean { + return workspace + .getConfiguration("dbt") + .get( + "disableQueryHistory", + DEFAULT_CONFIGURATION_VALUES.disableQueryHistory, + ); + } + + getInstallDepsOnProjectInitialization(): boolean { + return workspace + .getConfiguration("dbt") + .get( + "installDepsOnProjectInitialization", + DEFAULT_CONFIGURATION_VALUES.installDepsOnProjectInitialization, + ); + } + + getDisableDepthsCalculation(): boolean { + return workspace + .getConfiguration("dbt") + .get( + "disableDepthsCalculation", + DEFAULT_CONFIGURATION_VALUES.disableDepthsCalculation, + ); + } + + getWorkingDirectory(): string { + return getFirstWorkspacePath(); + } + + getAltimateUrl(): string { + return workspace + .getConfiguration("dbt") + .get("altimateUrl", DEFAULT_CONFIGURATION_VALUES.altimateUrl); + } + + getIsLocalMode(): boolean { + return workspace + .getConfiguration("dbt") + .get("isLocalMode", DEFAULT_CONFIGURATION_VALUES.isLocalMode); + } + + getAltimateInstanceName(): string | undefined { + return workspace + .getConfiguration("dbt") + .get("altimateInstanceName"); + } + + getAltimateAiKey(): string | undefined { + return workspace.getConfiguration("dbt").get("altimateAiKey"); + } +} diff --git a/src/dbt_client/dbtTerminal.ts b/src/dbt_client/vscodeTerminal.ts similarity index 78% rename from src/dbt_client/dbtTerminal.ts rename to src/dbt_client/vscodeTerminal.ts index 8197187a8..c8ab6b6d4 100644 --- a/src/dbt_client/dbtTerminal.ts +++ b/src/dbt_client/vscodeTerminal.ts @@ -1,10 +1,12 @@ +import { DBTTerminal } from "@altimateai/dbt-integration"; +import { injectable } from "inversify"; +import { PythonException } from "python-bridge"; import { Disposable, EventEmitter, Terminal, window } from "vscode"; -import { provideSingleton, stripANSI } from "../utils"; import { TelemetryService } from "../telemetry"; -import { PythonException } from "python-bridge"; +import { stripANSI } from "../utils"; -@provideSingleton(DBTTerminal) -export class DBTTerminal { +@injectable() +export class VSCodeDBTTerminal implements DBTTerminal { private disposables: Disposable[] = []; private terminal?: Terminal; private readonly writeEmitter = new EventEmitter(); @@ -21,43 +23,6 @@ export class DBTTerminal { } } - logNewLine() { - this.log("\r\n"); - } - - logLine(line: string) { - this.log(line); - this.logNewLine(); - } - - logHorizontalRule() { - this.logLine( - "--------------------------------------------------------------------------", - ); - } - - logBlock(block: string[]) { - this.logHorizontalRule(); - for (const line of block) { - if (line) { - this.logLine(line); - } - } - this.logHorizontalRule(); - } - - logBlockWithHeader(header: string[], block: string[]) { - this.logHorizontalRule(); - for (const line of header) { - this.logLine(line); - } - this.logHorizontalRule(); - for (const line of block) { - this.logLine(line); - } - this.logHorizontalRule(); - } - log(message: string, ...args: any[]) { this.outputChannel.info(stripANSI(message), args); 
console.log(stripANSI(message), args); diff --git a/src/definition_provider/docDefinitionProvider.ts b/src/definition_provider/docDefinitionProvider.ts index cfce0d7f5..3feb0d5ff 100755 --- a/src/definition_provider/docDefinitionProvider.ts +++ b/src/definition_provider/docDefinitionProvider.ts @@ -1,3 +1,4 @@ +import { DocMetaMap } from "@altimateai/dbt-integration"; import { Definition, DefinitionLink, @@ -9,13 +10,10 @@ import { TextDocument, Uri, } from "vscode"; -import { DocMetaMap } from "../domain"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; -@provideSingleton(DocDefinitionProvider) export class DocDefinitionProvider implements DefinitionProvider, Disposable { private docToLocationMap: Map = new Map(); private static readonly IS_DOC = /(doc)\([^)]*\)/; diff --git a/src/definition_provider/index.ts b/src/definition_provider/index.ts index cf5e07b6f..1f01cc8e2 100755 --- a/src/definition_provider/index.ts +++ b/src/definition_provider/index.ts @@ -1,12 +1,10 @@ import { Disposable, languages } from "vscode"; import { DBTPowerUserExtension } from "../dbtPowerUserExtension"; -import { provideSingleton } from "../utils"; import { DocDefinitionProvider } from "./docDefinitionProvider"; import { MacroDefinitionProvider } from "./macroDefinitionProvider"; import { ModelDefinitionProvider } from "./modelDefinitionProvider"; import { SourceDefinitionProvider } from "./sourceDefinitionProvider"; -@provideSingleton(DefinitionProviders) export class DefinitionProviders implements Disposable { private disposables: Disposable[] = []; diff --git a/src/definition_provider/macroDefinitionProvider.ts b/src/definition_provider/macroDefinitionProvider.ts index 0b011cc2c..16ee50a56 100755 --- a/src/definition_provider/macroDefinitionProvider.ts +++ b/src/definition_provider/macroDefinitionProvider.ts @@ -1,3 +1,4 @@ +import { MacroMetaMap } from "@altimateai/dbt-integration"; import { Definition, DefinitionLink, @@ -9,12 +10,10 @@ import { TextDocument, Uri, } from "vscode"; -import { MacroMetaMap } from "../domain"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { isEnclosedWithinCodeBlock, provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; -@provideSingleton(MacroDefinitionProvider) +import { isEnclosedWithinCodeBlock } from "../utils"; export class MacroDefinitionProvider implements DefinitionProvider, Disposable { private macroToLocationMap: Map = new Map(); private static readonly IS_MACRO = /\w+\.?\w+/; diff --git a/src/definition_provider/modelDefinitionProvider.ts b/src/definition_provider/modelDefinitionProvider.ts index 744110151..a921305a6 100755 --- a/src/definition_provider/modelDefinitionProvider.ts +++ b/src/definition_provider/modelDefinitionProvider.ts @@ -1,3 +1,5 @@ +import { DBTTerminal, NodeMetaMap } from "@altimateai/dbt-integration"; +import { inject } from "inversify"; import { 
CancellationToken, Definition, @@ -11,14 +13,10 @@ import { TextDocument, Uri, } from "vscode"; -import { NodeMetaMap } from "../domain"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; -@provideSingleton(ModelDefinitionProvider) export class ModelDefinitionProvider implements DefinitionProvider, Disposable { private modelToLocationMap: Map = new Map(); private static readonly IS_REF = /(ref)\([^)]*\)/; @@ -28,6 +26,7 @@ export class ModelDefinitionProvider implements DefinitionProvider, Disposable { constructor( private dbtProjectContainer: DBTProjectContainer, private telemetry: TelemetryService, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, ) { this.disposables.push( diff --git a/src/definition_provider/sourceDefinitionProvider.ts b/src/definition_provider/sourceDefinitionProvider.ts index ed4a5ac8b..daccf29d9 100755 --- a/src/definition_provider/sourceDefinitionProvider.ts +++ b/src/definition_provider/sourceDefinitionProvider.ts @@ -1,3 +1,4 @@ +import { SourceMetaMap } from "@altimateai/dbt-integration"; import { readFileSync } from "fs"; import { CancellationToken, @@ -11,13 +12,11 @@ import { TextDocument, Uri, } from "vscode"; -import { SourceMetaMap } from "../domain"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { isEnclosedWithinCodeBlock, provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; +import { isEnclosedWithinCodeBlock } from "../utils"; -@provideSingleton(SourceDefinitionProvider) export class SourceDefinitionProvider implements DefinitionProvider, Disposable { diff --git a/src/document_formatting_edit_provider/dbtDocumentFormattingEditProvider.ts b/src/document_formatting_edit_provider/dbtDocumentFormattingEditProvider.ts index cf8000506..cd585c512 100644 --- a/src/document_formatting_edit_provider/dbtDocumentFormattingEditProvider.ts +++ b/src/document_formatting_edit_provider/dbtDocumentFormattingEditProvider.ts @@ -1,4 +1,8 @@ +import { CommandProcessExecutionFactory } from "@altimateai/dbt-integration"; +import fs from "fs"; +import { inject } from "inversify"; import parseDiff from "parse-diff"; +import path from "path"; import { CancellationToken, DocumentFormattingEditProvider, @@ -10,24 +14,17 @@ import { workspace, } from "vscode"; import which from "which"; -import { CommandProcessExecutionFactory } from "../commandProcessExecution"; -import { - extendErrorWithSupportLinks, - getFirstWorkspacePath, - provideSingleton, -} from "../utils"; +import { PythonEnvironment } from "../dbt_client/pythonEnvironment"; import { TelemetryService } from "../telemetry"; -import { PythonEnvironment } from "../manifest/pythonEnvironment"; -import path from "path"; -import fs from "fs"; +import { extendErrorWithSupportLinks, getFirstWorkspacePath } from "../utils"; -@provideSingleton(DbtDocumentFormattingEditProvider) 
export class DbtDocumentFormattingEditProvider implements DocumentFormattingEditProvider { constructor( private commandProcessExecutionFactory: CommandProcessExecutionFactory, private telemetry: TelemetryService, + @inject(PythonEnvironment) private pythonEnvironment: PythonEnvironment, ) {} @@ -59,6 +56,9 @@ export class DbtDocumentFormattingEditProvider try { // try to find sqlfmt on PATH if not set const sqlFmtPath = sqlFmtPathSetting || (await this.findSqlFmtPath()); + if (!sqlFmtPath) { + throw new Error("sqlfmt not found"); + } this.telemetry.sendTelemetryEvent("formatDbtModel", { sqlFmtPath: sqlFmtPathSetting ? "setting" : "path", }); diff --git a/src/document_formatting_edit_provider/index.ts b/src/document_formatting_edit_provider/index.ts index f290a2019..d304437f0 100644 --- a/src/document_formatting_edit_provider/index.ts +++ b/src/document_formatting_edit_provider/index.ts @@ -1,9 +1,7 @@ import { Disposable, languages } from "vscode"; import { DBTPowerUserExtension } from "../dbtPowerUserExtension"; -import { provideSingleton } from "../utils"; import { DbtDocumentFormattingEditProvider } from "./dbtDocumentFormattingEditProvider"; -@provideSingleton(DocumentFormattingEditProviders) export class DocumentFormattingEditProviders implements Disposable { private disposables: Disposable[] = []; diff --git a/src/domain.ts b/src/domain.ts deleted file mode 100755 index 63c442279..000000000 --- a/src/domain.ts +++ /dev/null @@ -1,236 +0,0 @@ -import * as path from "path"; - -export type MacroMetaMap = Map; -export type MetricMetaMap = Map; -export type SourceMetaMap = Map; -export type TestMetaMap = Map; -export type ExposureMetaMap = Map; -export type DocMetaMap = Map; -export type NodeMetaType = NodeMetaData; -export type SourceMetaType = SourceTable; - -export interface NodeMetaMap { - lookupByBaseName(modelBaseName: string): NodeMetaData | undefined; - lookupByUniqueId(uniqueId: string): NodeMetaData | undefined; - nodes(): Iterable; -} - -export interface MacroMetaData { - path: string | undefined; // in dbt cloud, packages are not downloaded locally - line: number; - character: number; - uniqueId: string; - description?: string; - arguments?: { name: string; type: string; description: string }[]; - name: string; - depends_on: DependsOn; -} - -interface MetricMetaData { - name: string; -} - -export interface NodeMetaData { - uniqueId: string; - path: string | undefined; // in dbt cloud, packages are not downloaded locally - database: string; - schema: string; - alias: string; - name: string; - package_name: string; - description: string; - patch_path: string; - columns: { [columnName: string]: ColumnMetaData }; - config: Config; - resource_type: string; - depends_on: DependsOn; - is_external_project: boolean; - compiled_path: string; - meta: any; -} - -export interface ColumnMetaData { - name: string; - description: string; - data_type: string; - meta: any; -} - -interface Config { - materialized: string; -} - -export interface SourceMetaData { - uniqueId: string; - name: string; - database: string; - schema: string; - tables: SourceTable[]; - package_name: string; - is_external_project: boolean; - meta: any; -} - -export interface SourceTable { - name: string; - identifier: string; - path: string | undefined; // in dbt cloud, packages are not downloaded locally - description: string; - columns: { [columnName: string]: ColumnMetaData }; -} - -interface DocMetaData { - path: string; - line: number; - character: number; -} - -interface TestMetadataSpecification { - column_name: 
string; - model: string; -} - -// for accepted_values -export interface TestMetadataAcceptedValues extends TestMetadataSpecification { - values?: string[]; -} - -// for relationship -export interface TestMetadataRelationships extends TestMetadataSpecification { - field?: string; - to?: string; -} - -interface DependsOn { - macros: [string]; - nodes: [string]; - sources: [string]; -} - -export interface TestMetaData { - path: string | undefined; // in dbt cloud, packages are not downloaded locally - database: string; - schema: string; - alias: string; - raw_sql: string; - column_name?: string; - test_metadata?: { - kwargs: TestMetadataAcceptedValues | TestMetadataRelationships; - name: string; - namespace?: string; - }; - attached_node?: string; - depends_on: DependsOn; - uniqueId: string; -} - -export interface ExposureMetaData { - description?: string; - depends_on: DependsOn; - label?: string; - maturity?: string; - name: string; - owner: { email: string; name: string }; - tags: [string]; - url?: string; - type: string; - config: { enabled: boolean }; - path: string | undefined; // in dbt cloud, packages are not downloaded locally - unique_id: string; - sources?: [string]; - metrics?: unknown[]; - meta?: Record; -} - -interface NodeGraphMetaData { - currentNode: Node; - nodes: Node[]; -} - -interface ModelGraphMetaData { - uniqueId: string; - name: string; - dependencies?: string[]; -} - -export type NodeGraphMap = Map; -export type ModelGraphMetaMap = Map; - -export interface GraphMetaMap { - parents: NodeGraphMap; - children: NodeGraphMap; - tests: NodeGraphMap; - metrics: NodeGraphMap; -} - -interface IconPath { - light: string; - dark: string; -} - -export abstract class Node { - label: string; - key: string; - url: string | undefined; - iconPath: IconPath = { - light: path.join( - path.resolve(__dirname), - "../media/images/model_light.svg", - ), - dark: path.join(path.resolve(__dirname), "../media/images/model_dark.svg"), - }; - displayInModelTree: boolean = true; - - constructor(label: string, key: string, url?: string) { - this.label = label; - this.key = key; - this.url = url; - } -} - -export class Model extends Node {} - -export class Seed extends Node {} -export class Test extends Node { - // displayInModelTree = false; - iconPath = { - light: path.join( - path.resolve(__dirname), - "../media/images/source_light.svg", - ), - dark: path.join(path.resolve(__dirname), "../media/images/source_dark.svg"), - }; -} -export class Analysis extends Node { - displayInModelTree = true; -} -export class Exposure extends Node { - displayInModelTree = true; -} -export class Metric extends Node { - displayInModelTree = false; -} -export class Snapshot extends Node {} -export class Source extends Node { - iconPath = { - light: path.join( - path.resolve(__dirname), - "../media/images/source_light.svg", - ), - dark: path.join(path.resolve(__dirname), "../media/images/source_dark.svg"), - }; -} - -export enum RunModelType { - RUN_PARENTS, - RUN_CHILDREN, - BUILD_PARENTS, - BUILD_CHILDREN, - BUILD_CHILDREN_PARENTS, - TEST, - SNAPSHOT, -} - -export interface EnvironmentVariables { - [key: string]: string | undefined; -} diff --git a/src/exceptions/executionsExhaustedException.ts b/src/exceptions/executionsExhaustedException.ts deleted file mode 100644 index 9f44ef5b5..000000000 --- a/src/exceptions/executionsExhaustedException.ts +++ /dev/null @@ -1,7 +0,0 @@ -export class ExecutionsExhaustedException extends Error { - constructor(msg: string) { - super(msg); - - Object.setPrototypeOf(this, 
ExecutionsExhaustedException.prototype); - } -} diff --git a/src/exceptions/index.ts b/src/exceptions/index.ts deleted file mode 100644 index b9f67813e..000000000 --- a/src/exceptions/index.ts +++ /dev/null @@ -1,2 +0,0 @@ -export * from "./rateLimitException"; -export * from "./executionsExhaustedException"; diff --git a/src/exceptions/noCredentialsError.ts b/src/exceptions/noCredentialsError.ts deleted file mode 100644 index 57ba1c2d2..000000000 --- a/src/exceptions/noCredentialsError.ts +++ /dev/null @@ -1,6 +0,0 @@ -export class NoCredentialsError extends Error { - constructor(message: string) { - super(message); - this.name = "NoCredentialsError"; - } -} diff --git a/src/exceptions/rateLimitException.ts b/src/exceptions/rateLimitException.ts deleted file mode 100644 index a9de9a831..000000000 --- a/src/exceptions/rateLimitException.ts +++ /dev/null @@ -1,9 +0,0 @@ -export class RateLimitException extends Error { - public retryAfter: number; - constructor(msg: string, retryAfter: number) { - super(msg); - - this.retryAfter = retryAfter; - Object.setPrototypeOf(this, RateLimitException.prototype); - } -} diff --git a/src/hover_provider/depthDecorationProvider.ts b/src/hover_provider/depthDecorationProvider.ts index 8002a33d9..d9a8828f1 100644 --- a/src/hover_provider/depthDecorationProvider.ts +++ b/src/hover_provider/depthDecorationProvider.ts @@ -12,11 +12,9 @@ import { window, workspace, } from "vscode"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; import { getDepthColor } from "../utils"; -@provideSingleton(DepthDecorationProvider) export class DepthDecorationProvider implements HoverProvider, Disposable { private disposables: Disposable[] = []; private readonly REF_PATTERN = diff --git a/src/hover_provider/index.ts b/src/hover_provider/index.ts index 5d0a2af36..13384ea22 100644 --- a/src/hover_provider/index.ts +++ b/src/hover_provider/index.ts @@ -1,12 +1,10 @@ import { Disposable, languages } from "vscode"; import { DBTPowerUserExtension } from "../dbtPowerUserExtension"; -import { provideSingleton } from "../utils"; +import { DepthDecorationProvider } from "./depthDecorationProvider"; +import { MacroHoverProvider } from "./macroHoverProvider"; import { ModelHoverProvider } from "./modelHoverProvider"; import { SourceHoverProvider } from "./sourceHoverProvider"; -import { MacroHoverProvider } from "./macroHoverProvider"; -import { DepthDecorationProvider } from "./depthDecorationProvider"; -@provideSingleton(HoverProviders) export class HoverProviders implements Disposable { private disposables: Disposable[] = []; diff --git a/src/hover_provider/macroHoverProvider.ts b/src/hover_provider/macroHoverProvider.ts index 323e999b4..ccad6d4b6 100644 --- a/src/hover_provider/macroHoverProvider.ts +++ b/src/hover_provider/macroHoverProvider.ts @@ -1,31 +1,30 @@ +import { + DBTTerminal, + MacroMetaData, + MacroMetaMap, + NodeMetaData, + NodeMetaMap, +} from "@altimateai/dbt-integration"; +import { inject } from "inversify"; import { CancellationToken, - HoverProvider, + Disposable, Hover, + HoverProvider, Position, ProviderResult, TextDocument, - Disposable, } from "vscode"; +import { QueryManifestService } from "../services/queryManifestService"; import { TelemetryService } from "../telemetry"; import { generateMacroHoverMarkdown } from "./utils"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; -import { QueryManifestService } from 
"../services/queryManifestService"; -import { provideSingleton } from "../utils"; -import { - MacroMetaData, - MacroMetaMap, - NodeMetaData, - NodeMetaMap, - SourceMetaMap, -} from "../domain"; -@provideSingleton(MacroHoverProvider) export class MacroHoverProvider implements HoverProvider, Disposable { private disposables: Disposable[] = []; constructor( private telemetry: TelemetryService, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, private queryManifestService: QueryManifestService, ) {} diff --git a/src/hover_provider/modelHoverProvider.ts b/src/hover_provider/modelHoverProvider.ts index fa740e6a7..7313959b8 100755 --- a/src/hover_provider/modelHoverProvider.ts +++ b/src/hover_provider/modelHoverProvider.ts @@ -1,25 +1,23 @@ +import { DBTTerminal, NodeMetaMap } from "@altimateai/dbt-integration"; +import { inject } from "inversify"; import { CancellationToken, - HoverProvider, Disposable, + Hover, + HoverProvider, + MarkdownString, Position, ProviderResult, Range, TextDocument, Uri, - Hover, - MarkdownString, } from "vscode"; -import { NodeMetaMap } from "../domain"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { provideSingleton } from "../utils"; +import { DBTProject } from "../dbt_client/dbtProject"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; import { generateHoverMarkdownString } from "./utils"; -import { DBTTerminal } from "../dbt_client/dbtTerminal"; -import { DBTProject } from "../manifest/dbtProject"; -@provideSingleton(ModelHoverProvider) export class ModelHoverProvider implements HoverProvider, Disposable { private modelToLocationMap: Map = new Map(); private static readonly IS_REF = /(ref)\([^)]*\)/; @@ -29,6 +27,7 @@ export class ModelHoverProvider implements HoverProvider, Disposable { constructor( private dbtProjectContainer: DBTProjectContainer, private telemetry: TelemetryService, + @inject("DBTTerminal") private dbtTerminal: DBTTerminal, ) { this.disposables.push( diff --git a/src/hover_provider/sourceHoverProvider.ts b/src/hover_provider/sourceHoverProvider.ts index 2427c5f44..19ef46a3f 100755 --- a/src/hover_provider/sourceHoverProvider.ts +++ b/src/hover_provider/sourceHoverProvider.ts @@ -1,22 +1,21 @@ +import { SourceMetaMap } from "@altimateai/dbt-integration"; import { CancellationToken, - HoverProvider, Disposable, + Hover, + HoverProvider, + MarkdownString, Position, ProviderResult, TextDocument, Uri, - Hover, - MarkdownString, } from "vscode"; -import { SourceMetaMap } from "../domain"; -import { DBTProjectContainer } from "../manifest/dbtProjectContainer"; -import { ManifestCacheChangedEvent } from "../manifest/event/manifestCacheChangedEvent"; -import { isEnclosedWithinCodeBlock, provideSingleton } from "../utils"; +import { DBTProjectContainer } from "../dbt_client/dbtProjectContainer"; +import { ManifestCacheChangedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; import { TelemetryService } from "../telemetry"; +import { isEnclosedWithinCodeBlock } from "../utils"; import { generateHoverMarkdownString } from "./utils"; -@provideSingleton(SourceHoverProvider) export class SourceHoverProvider implements HoverProvider, Disposable { private sourceMetaMap: Map = new Map(); private static readonly IS_SOURCE = /(source)\([^)]*\)/; diff --git 
a/src/hover_provider/utils.ts b/src/hover_provider/utils.ts index 4a7427091..80fa26ab7 100644 --- a/src/hover_provider/utils.ts +++ b/src/hover_provider/utils.ts @@ -1,11 +1,11 @@ -import { MarkdownString, Uri } from "vscode"; import { MacroMetaData, NodeMetaData, NodeMetaType, SourceMetaType, -} from "../domain"; -import { ManifestCacheProjectAddedEvent } from "../manifest/event/manifestCacheChangedEvent"; +} from "@altimateai/dbt-integration"; +import { MarkdownString, Uri } from "vscode"; +import { ManifestCacheProjectAddedEvent } from "../dbt_client/event/manifestCacheChangedEvent"; export function generateHoverMarkdownString( node: NodeMetaType | SourceMetaType, diff --git a/src/inversify.config.ts b/src/inversify.config.ts index cc3065376..7382a0356 100755 --- a/src/inversify.config.ts +++ b/src/inversify.config.ts @@ -1,59 +1,441 @@ -import { Container, interfaces } from "inversify"; -import { buildProviderModule } from "inversify-binding-decorators"; -import { - DiagnosticCollection, - EventEmitter, - Uri, - workspace, - WorkspaceFolder, -} from "vscode"; -import { DBTTerminal } from "./dbt_client/dbtTerminal"; -import { DBTProject } from "./manifest/dbtProject"; -import { - DBTProjectContainer, - ProjectRegisteredUnregisteredEvent, -} from "./manifest/dbtProjectContainer"; -import { DBTWorkspaceFolder } from "./manifest/dbtWorkspaceFolder"; -import { ManifestCacheChangedEvent } from "./manifest/event/manifestCacheChangedEvent"; -import { DBTProjectLogFactory } from "./manifest/modules/dbtProjectLog"; -import { SourceFileWatchersFactory } from "./manifest/modules/sourceFileWatchers"; -import { TargetWatchersFactory } from "./manifest/modules/targetWatchers"; -import { PythonEnvironment } from "./manifest/pythonEnvironment"; -import { TelemetryService } from "./telemetry"; -import { - DBTCoreDetection, - DBTCoreProjectDetection, - DBTCoreProjectIntegration, -} from "./dbt_client/dbtCoreIntegration"; import { + AltimateHttpClient, + ChildrenParentParser, CLIDBTCommandExecutionStrategy, + CommandProcessExecutionFactory, + DBTCloudDetection, + DBTCloudProjectDetection, + DBTCloudProjectIntegration, DBTCommandExecutionInfrastructure, DBTCommandExecutionStrategy, DBTCommandFactory, + DBTConfiguration, + DBTCoreCommandDetection, + DBTCoreCommandProjectDetection, + DBTCoreCommandProjectIntegration, + DBTCoreDetection, + DBTCoreProjectDetection, + DBTCoreProjectIntegration, DBTDetection, + DBTDiagnosticData, + DBTFusionCommandDetection, + DBTFusionCommandProjectDetection, + DBTFusionCommandProjectIntegration, + DbtIntegrationClient, DBTProjectDetection, + DBTProjectIntegrationAdapter, + DBTTerminal, + DeferConfig, + DocParser, + ExposureParser, + GraphParser, + MacroParser, + MetricParser, + ModelDepthParser, + NodeParser, PythonDBTCommandExecutionStrategy, -} from "./dbt_client/dbtIntegration"; -import { - DBTCloudDetection, - DBTCloudProjectDetection, - DBTCloudProjectIntegration, -} from "./dbt_client/dbtCloudIntegration"; -import { CommandProcessExecutionFactory } from "./commandProcessExecution"; + PythonEnvironmentProvider, + RuntimePythonEnvironment, + SourceParser, + TestParser, +} from "@altimateai/dbt-integration"; +import * as LibNamespace from "@lib"; +import { NotebookKernelClient } from "@lib"; +import { Container, interfaces } from "inversify"; +import { Event, EventEmitter, Uri, workspace, WorkspaceFolder } from "vscode"; import { AltimateRequest } from "./altimate"; -import { ValidationProvider } from "./validation_provider"; -import { DeferToProdService } from 
"./services/deferToProdService"; +import { DBTProject } from "./dbt_client/dbtProject"; +import { ProjectRegisteredUnregisteredEvent } from "./dbt_client/dbtProjectContainer"; +import { DBTProjectLog } from "./dbt_client/dbtProjectLog"; +import { DBTWorkspaceFolder } from "./dbt_client/dbtWorkspaceFolder"; +import { ManifestCacheChangedEvent } from "./dbt_client/event/manifestCacheChangedEvent"; +import { ProjectConfigChangedEvent } from "./dbt_client/event/projectConfigChangedEvent"; +import { PythonEnvironment } from "./dbt_client/pythonEnvironment"; +import { + StaticRuntimePythonEnvironment, + VSCodeRuntimePythonEnvironmentProvider, +} from "./dbt_client/runtimePythonEnvironmentProvider"; +import { VSCodeDBTConfiguration } from "./dbt_client/vscodeConfiguration"; +import { VSCodeDBTTerminal } from "./dbt_client/vscodeTerminal"; +import { AltimateAuthService } from "./services/altimateAuthService"; +import { ConversationService } from "./services/conversationService"; +import { DbtLineageService } from "./services/dbtLineageService"; +import { DbtTestService } from "./services/dbtTestService"; +import { DiagnosticsOutputChannel } from "./services/diagnosticsOutputChannel"; +import { DocGenService } from "./services/docGenService"; +import { FileService } from "./services/fileService"; +import { QueryAnalysisService } from "./services/queryAnalysisService"; +import { QueryManifestService } from "./services/queryManifestService"; import { SharedStateService } from "./services/sharedStateService"; -import { NotebookKernelClient, NotebookDependencies } from "@lib"; -import { DBTCoreCommandProjectIntegration } from "./dbt_client/dbtCoreCommandIntegration"; +import { StreamingService } from "./services/streamingService"; +import { UsersService } from "./services/usersService"; +import { TelemetryService } from "./telemetry"; +import { ValidationProvider } from "./validation_provider"; + +// Core extension components +import { DBTClient } from "./dbt_client"; +import { AltimateDatapilot } from "./dbt_client/datapilot"; +import { DBTProjectContainer } from "./dbt_client/dbtProjectContainer"; +import { DbtPowerUserMcpServer } from "./mcp"; +import { DbtPowerUserMcpServerTools } from "./mcp/server"; + +// Import providers +import { AutocompletionProviders } from "./autocompletion_provider"; +import { DocAutocompletionProvider } from "./autocompletion_provider/docAutocompletionProvider"; +import { MacroAutocompletionProvider } from "./autocompletion_provider/macroAutocompletionProvider"; +import { ModelAutocompletionProvider } from "./autocompletion_provider/modelAutocompletionProvider"; +import { SourceAutocompletionProvider } from "./autocompletion_provider/sourceAutocompletionProvider"; +import { UserCompletionProvider } from "./autocompletion_provider/usercompletion_provider"; +import { CodeLensProviders } from "./code_lens_provider"; +import { CteCodeLensProvider } from "./code_lens_provider/cteCodeLensProvider"; +import { DocumentationCodeLensProvider } from "./code_lens_provider/documentationCodeLensProvider"; +import { SourceModelCreationCodeLensProvider } from "./code_lens_provider/sourceModelCreationCodeLensProvider"; +import { VirtualSqlCodeLensProvider } from "./code_lens_provider/virtualSqlCodeLensProvider"; +import { DefinitionProviders } from "./definition_provider"; +import { DocDefinitionProvider } from "./definition_provider/docDefinitionProvider"; +import { MacroDefinitionProvider } from "./definition_provider/macroDefinitionProvider"; +import { ModelDefinitionProvider } from 
"./definition_provider/modelDefinitionProvider"; +import { SourceDefinitionProvider } from "./definition_provider/sourceDefinitionProvider"; +import { HoverProviders } from "./hover_provider"; +import { DepthDecorationProvider } from "./hover_provider/depthDecorationProvider"; +import { MacroHoverProvider } from "./hover_provider/macroHoverProvider"; +import { ModelHoverProvider } from "./hover_provider/modelHoverProvider"; +import { SourceHoverProvider } from "./hover_provider/sourceHoverProvider"; +import { ProjectQuickPick } from "./quickpick/projectQuickPick"; + +// Import missing providers and components +import { VSCodeCommands } from "./commands"; +import { AltimateScan } from "./commands/altimateScan"; +import { BigQueryCostEstimate } from "./commands/bigQueryCostEstimate"; +import { RunModel } from "./commands/runModel"; +import { SqlToModel } from "./commands/sqlToModel"; +import { MissingSchemaTest } from "./commands/tests/missingSchemaTest"; +import { StaleModelColumnTest } from "./commands/tests/staleModelColumnTest"; +import { UndocumentedModelColumnTest } from "./commands/tests/undocumentedModelColumnTest"; +import { UnmaterializedModelTest } from "./commands/tests/unmaterializedModelTest"; +import { ValidateSql } from "./commands/validateSql"; +import { WalkthroughCommands } from "./commands/walkthroughCommands"; +import { CommentProviders } from "./comment_provider"; +import { ConversationProvider } from "./comment_provider/conversationProvider"; +import { ContentProviders } from "./content_provider"; +import { SqlPreviewContentProvider } from "./content_provider/sqlPreviewContentProvider"; +import { DBTPowerUserExtension } from "./dbtPowerUserExtension"; +import { DocumentFormattingEditProviders } from "./document_formatting_edit_provider"; +import { DbtDocumentFormattingEditProvider } from "./document_formatting_edit_provider/dbtDocumentFormattingEditProvider"; +import { DbtPowerUserActionsCenter } from "./quickpick"; +import { DbtPowerUserControlCenterAction } from "./quickpick/puQuickPick"; +import { DbtSQLAction } from "./quickpick/sqlQuickPick"; +import { StatusBars } from "./statusbar"; +import { DeferToProductionStatusBar } from "./statusbar/deferToProductionStatusBar"; +import { TargetStatusBar } from "./statusbar/targetStatusBar"; +import { VersionStatusBar } from "./statusbar/versionStatusBar"; +import { TreeviewProviders } from "./treeview_provider"; import { - DBTFusionCommandDetection, - DBTFusionCommandProjectDetection, - DBTFusionCommandProjectIntegration, -} from "./dbt_client/dbtFusionCommandIntegration"; + ChildrenModelTreeview, + DocumentationTreeview, + IconActionsTreeview, + ModelTestTreeview, + ParentModelTreeview, +} from "./treeview_provider/modelTreeviewProvider"; +import { WebviewViewProviders } from "./webview_provider"; +import { DataPilotPanel } from "./webview_provider/datapilotPanel"; +import { DbtDocsView } from "./webview_provider/DbtDocsView"; +import { DocsEditViewPanel } from "./webview_provider/docsEditPanel"; +import { InsightsPanel } from "./webview_provider/insightsPanel"; +import { LineagePanel } from "./webview_provider/lineagePanel"; +import { ModelGraphViewPanel } from "./webview_provider/modelGraphViewPanel"; +import { NewDocsGenPanel } from "./webview_provider/newDocsGenPanel"; +import { NewLineagePanel } from "./webview_provider/newLineagePanel"; +import { QueryResultPanel } from "./webview_provider/queryResultPanel"; +import { SQLLineagePanel } from "./webview_provider/sqlLineagePanel"; export const container = new 
Container(); -container.load(buildProviderModule()); + +// Bind parser classes +container + .bind(ChildrenParentParser) + .toDynamicValue(() => new ChildrenParentParser()); +container + .bind(NodeParser) + .toDynamicValue( + (context) => new NodeParser(context.container.get("DBTTerminal")), + ); +container + .bind(MacroParser) + .toDynamicValue( + (context) => new MacroParser(context.container.get("DBTTerminal")), + ); +container + .bind(MetricParser) + .toDynamicValue( + (context) => new MetricParser(context.container.get("DBTTerminal")), + ); +container + .bind(GraphParser) + .toDynamicValue( + (context) => new GraphParser(context.container.get("DBTTerminal")), + ); +container + .bind(SourceParser) + .toDynamicValue( + (context) => new SourceParser(context.container.get("DBTTerminal")), + ); +container + .bind(TestParser) + .toDynamicValue( + (context) => new TestParser(context.container.get("DBTTerminal")), + ); +container + .bind(ExposureParser) + .toDynamicValue( + (context) => new ExposureParser(context.container.get("DBTTerminal")), + ); +container + .bind(DocParser) + .toDynamicValue( + (context) => new DocParser(context.container.get("DBTTerminal")), + ); +container + .bind(ModelDepthParser) + .toDynamicValue( + (context) => + new ModelDepthParser( + context.container.get("DBTTerminal"), + context.container.get(DbtIntegrationClient), + context.container.get("DBTConfiguration"), + ), + ); + +// Bind core dbt integration classes using factory functions +container + .bind(CLIDBTCommandExecutionStrategy) + .toDynamicValue(() => { + // Note: CLIDBTCommandExecutionStrategy requires projectRoot and dbtPath at construction time + // These will be provided by the factory functions that create instances + throw new Error( + "CLIDBTCommandExecutionStrategy should be created via Factory", + ); + }) + .inSingletonScope(); + +container + .bind(PythonDBTCommandExecutionStrategy) + .toDynamicValue((context) => { + return new PythonDBTCommandExecutionStrategy( + context.container.get(CommandProcessExecutionFactory), + context.container.get("RuntimePythonEnvironment"), + context.container.get("DBTTerminal"), + context.container.get("DBTConfiguration"), + ); + }) + .inSingletonScope(); + +container.bind(DBTCommandExecutionInfrastructure).toDynamicValue((context) => { + return new DBTCommandExecutionInfrastructure( + context.container.get("RuntimePythonEnvironment"), + context.container.get("DBTTerminal"), + ); +}); + +container + .bind(DBTCommandFactory) + .toDynamicValue((context) => { + return new DBTCommandFactory(context.container.get("DBTConfiguration")); + }) + .inSingletonScope(); + +// Bind dbt core integration classes using factory functions +container + .bind(DBTCoreDetection) + .toDynamicValue((context) => { + return new DBTCoreDetection( + context.container.get("RuntimePythonEnvironment"), + context.container.get(CommandProcessExecutionFactory), + ); + }) + .inSingletonScope(); + +container + .bind(DBTCoreProjectDetection) + .toDynamicValue((context) => { + return new DBTCoreProjectDetection( + context.container.get(DBTCommandExecutionInfrastructure), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +// Note: DBTCoreProjectIntegration requires projectRoot at construction time +// It will be created via Factory +container + .bind(DBTCoreProjectIntegration) + .toDynamicValue(() => { + throw new Error( + "DBTCoreProjectIntegration should be created via Factory", + ); + }) + .inSingletonScope(); + +// Bind dbt cloud integration classes using factory functions 
+container + .bind(DBTCloudDetection) + .toDynamicValue((context) => { + return new DBTCloudDetection( + context.container.get(CommandProcessExecutionFactory), + context.container.get("RuntimePythonEnvironment"), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +container + .bind(DBTCloudProjectDetection) + .toDynamicValue(() => { + return new DBTCloudProjectDetection(); + }) + .inSingletonScope(); + +// Note: DBTCloudProjectIntegration requires projectRoot at construction time +// It will be created via Factory +container + .bind(DBTCloudProjectIntegration) + .toDynamicValue(() => { + throw new Error( + "DBTCloudProjectIntegration should be created via Factory", + ); + }) + .inSingletonScope(); + +// Bind dbt fusion integration classes using factory functions +container + .bind(DBTFusionCommandDetection) + .toDynamicValue((context) => { + return new DBTFusionCommandDetection( + context.container.get(CommandProcessExecutionFactory), + context.container.get("RuntimePythonEnvironment"), + context.container.get("DBTTerminal"), + context.container.get("DBTConfiguration"), + ); + }) + .inSingletonScope(); + +container + .bind(DBTFusionCommandProjectDetection) + .toDynamicValue(() => { + return new DBTFusionCommandProjectDetection(); + }) + .inSingletonScope(); + +// Note: DBTFusionCommandProjectIntegration requires projectRoot at construction time +// It will be created via Factory +container + .bind(DBTFusionCommandProjectIntegration) + .toDynamicValue(() => { + throw new Error( + "DBTFusionCommandProjectIntegration should be created via Factory", + ); + }) + .inSingletonScope(); + +// Bind dbt core command integration classes using factory functions +container + .bind(DBTCoreCommandDetection) + .toDynamicValue((context) => { + return new DBTCoreCommandDetection( + context.container.get("RuntimePythonEnvironment"), + context.container.get(CommandProcessExecutionFactory), + ); + }) + .inSingletonScope(); + +container + .bind(DBTCoreCommandProjectDetection) + .toDynamicValue((context) => { + return new DBTCoreCommandProjectDetection( + context.container.get(DBTCommandExecutionInfrastructure), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +// Note: DBTCoreCommandProjectIntegration requires projectRoot at construction time +// It will be created via Factory +container + .bind(DBTCoreCommandProjectIntegration) + .toDynamicValue(() => { + throw new Error( + "DBTCoreCommandProjectIntegration should be created via Factory", + ); + }) + .inSingletonScope(); + +// Bind DBTConfiguration +container + .bind("DBTConfiguration") + .to(VSCodeDBTConfiguration) + .inSingletonScope(); + +// Bind DBTTerminal +container + .bind("DBTTerminal") + .to(VSCodeDBTTerminal) + .inSingletonScope(); + +// Bind RuntimePythonEnvironment (VSCode-free version for dbt_integration) +container + .bind("RuntimePythonEnvironment") + .to(StaticRuntimePythonEnvironment) + .inSingletonScope(); + +// Bind PythonEnvironmentProvider +container + .bind("PythonEnvironmentProvider") + .to(VSCodeRuntimePythonEnvironmentProvider) + .inSingletonScope(); + +// Bind CommandProcessExecutionFactory +container + .bind(CommandProcessExecutionFactory) + .toDynamicValue((context) => { + return new CommandProcessExecutionFactory( + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +// Bind AltimateHttpClient +container + .bind(AltimateHttpClient) + .toDynamicValue((context) => { + return new AltimateHttpClient( + context.container.get("DBTTerminal"), + 
context.container.get("DBTConfiguration"), + ); + }) + .inSingletonScope(); + +// Bind DbtIntegrationClient +container + .bind(DbtIntegrationClient) + .toDynamicValue((context) => { + return new DbtIntegrationClient( + context.container.get(AltimateHttpClient), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +// Bind AltimateRequest +container + .bind(AltimateRequest) + .toDynamicValue((context) => { + return new AltimateRequest( + context.container.get("DBTTerminal"), + context.container.get("DBTConfiguration"), + context.container.get(AltimateHttpClient), + ); + }) + .inSingletonScope(); container .bind>("Factory") @@ -86,6 +468,8 @@ container switch (dbtIntegrationMode) { case "cloud": + // Handle preview features for cloud integration + container.get(AltimateAuthService).handlePreviewFeatures(); return container.get(DBTCloudProjectDetection); case "fusion": return container.get(DBTFusionCommandProjectDetection); @@ -115,7 +499,7 @@ container container.get("Factory"), container.get("Factory"), container.get(TelemetryService), - container.get(DBTTerminal), + container.get("DBTTerminal"), workspaceFolder, _onManifestChanged, _onProjectRegisteredUnregistered, @@ -129,15 +513,14 @@ container >("Factory") .toFactory< CLIDBTCommandExecutionStrategy, - [Uri, string] + [string, string] >((context: interfaces.Context) => { - return (projectRoot: Uri, dbtPath: string) => { + return (projectRoot: string, dbtPath: string) => { const { container } = context; return new CLIDBTCommandExecutionStrategy( container.get(CommandProcessExecutionFactory), - container.get(PythonEnvironment), - container.get(DBTTerminal), - container.get(TelemetryService), + container.get("RuntimePythonEnvironment"), + container.get("DBTTerminal"), projectRoot, dbtPath, ); @@ -150,26 +533,28 @@ container >("Factory") .toFactory< DBTCoreProjectIntegration, - [Uri, DiagnosticCollection] + [string, DBTDiagnosticData[], DeferConfig, () => void] >((context: interfaces.Context) => { return ( - projectRoot: Uri, - projectConfigDiagnostics: DiagnosticCollection, + projectRoot: string, + projectConfigDiagnostics: DBTDiagnosticData[], + deferConfig: DeferConfig, + onDiagnosticsChanged: () => void, ) => { const { container } = context; return new DBTCoreProjectIntegration( container.get(DBTCommandExecutionInfrastructure), - container.get(PythonEnvironment), - container.get(TelemetryService), + container.get("RuntimePythonEnvironment"), + container.get("PythonEnvironmentProvider"), container.get(PythonDBTCommandExecutionStrategy), container.get("Factory"), - container.get(DBTProjectContainer), - container.get(AltimateRequest), - container.get(DBTTerminal), - container.get(ValidationProvider), - container.get(DeferToProdService), + container.get("DBTTerminal"), + container.get("DBTConfiguration"), + container.get(DbtIntegrationClient), projectRoot, projectConfigDiagnostics, + deferConfig, + onDiagnosticsChanged, ); }; }); @@ -180,26 +565,28 @@ container >("Factory") .toFactory< DBTCoreCommandProjectIntegration, - [Uri, DiagnosticCollection] + [string, DBTDiagnosticData[], DeferConfig, () => void] >((context: interfaces.Context) => { return ( - projectRoot: Uri, - projectConfigDiagnostics: DiagnosticCollection, + projectRoot: string, + projectConfigDiagnostics: DBTDiagnosticData[], + deferConfig: DeferConfig, + onDiagnosticsChanged: () => void, ) => { const { container } = context; return new DBTCoreCommandProjectIntegration( container.get(DBTCommandExecutionInfrastructure), - container.get(PythonEnvironment), - 
container.get(TelemetryService), + container.get("RuntimePythonEnvironment"), + container.get("PythonEnvironmentProvider"), container.get(PythonDBTCommandExecutionStrategy), container.get("Factory"), - container.get(DBTProjectContainer), - container.get(AltimateRequest), - container.get(DBTTerminal), - container.get(ValidationProvider), - container.get(DeferToProdService), + container.get("DBTTerminal"), + container.get("DBTConfiguration"), + container.get(DbtIntegrationClient), projectRoot, projectConfigDiagnostics, + deferConfig, + onDiagnosticsChanged, ); }; }); @@ -210,21 +597,26 @@ container >("Factory") .toFactory< DBTFusionCommandProjectIntegration, - [Uri, DiagnosticCollection] + [string, DBTDiagnosticData[], DeferConfig, () => void] >((context: interfaces.Context) => { - return (projectRoot: Uri) => { + return ( + projectRoot: string, + projectConfigDiagnostics: DBTDiagnosticData[], + deferConfig: DeferConfig, + onDiagnosticsChanged: () => void, + ) => { const { container } = context; return new DBTFusionCommandProjectIntegration( container.get(DBTCommandExecutionInfrastructure), container.get(DBTCommandFactory), container.get("Factory"), - container.get(TelemetryService), - container.get(PythonEnvironment), - container.get(DBTTerminal), - container.get(ValidationProvider), - container.get(DeferToProdService), + container.get("RuntimePythonEnvironment"), + container.get("PythonEnvironmentProvider"), + container.get("DBTTerminal"), projectRoot, - container.get(AltimateRequest), + projectConfigDiagnostics, + deferConfig, + onDiagnosticsChanged, ); }; }); @@ -235,21 +627,60 @@ container >("Factory") .toFactory< DBTCloudProjectIntegration, - [Uri] + [string, DBTDiagnosticData[], DeferConfig, () => void] >((context: interfaces.Context) => { - return (projectRoot: Uri) => { + return ( + projectRoot: string, + projectConfigDiagnostics: DBTDiagnosticData[], + deferConfig: DeferConfig, + onDiagnosticsChanged: () => void, + ) => { const { container } = context; return new DBTCloudProjectIntegration( container.get(DBTCommandExecutionInfrastructure), container.get(DBTCommandFactory), container.get("Factory"), - container.get(TelemetryService), - container.get(PythonEnvironment), - container.get(DBTTerminal), - container.get(ValidationProvider), - container.get(DeferToProdService), + container.get("RuntimePythonEnvironment"), + container.get("PythonEnvironmentProvider"), + container.get("DBTTerminal"), projectRoot, - container.get(AltimateRequest), + projectConfigDiagnostics, + deferConfig, + onDiagnosticsChanged, + ); + }; + }); + +container + .bind< + interfaces.Factory + >("Factory") + .toFactory< + DBTProjectIntegrationAdapter, + [string, DeferConfig | undefined] + >((context: interfaces.Context) => { + return (projectRoot: string, deferConfig: DeferConfig | undefined) => { + const { container } = context; + return new DBTProjectIntegrationAdapter( + container.get("DBTConfiguration"), + container.get(DBTCommandFactory), + container.get("Factory"), + container.get("Factory"), + container.get("Factory"), + container.get("Factory"), + projectRoot, + deferConfig, + container.get(ChildrenParentParser), + container.get(NodeParser), + container.get(MacroParser), + container.get(MetricParser), + container.get(GraphParser), + container.get(SourceParser), + container.get(TestParser), + container.get(ExposureParser), + container.get(DocParser), + container.get("DBTTerminal"), + container.get(ModelDepthParser), ); }; }); @@ -268,19 +699,16 @@ container const { container } = context; return new 
DBTProject( container.get(PythonEnvironment), - container.get(SourceFileWatchersFactory), - container.get(DBTProjectLogFactory), - container.get(TargetWatchersFactory), + container.get("Factory"), container.get(DBTCommandFactory), - container.get(DBTTerminal), + container.get("DBTTerminal"), container.get(SharedStateService), container.get(TelemetryService), - container.get("Factory"), - container.get("Factory"), - container.get("Factory"), - container.get("Factory"), + container.get(DBTCommandExecutionInfrastructure), + container.get("Factory"), container.get(AltimateRequest), container.get(ValidationProvider), + container.get(AltimateAuthService), path, projectConfig, _onManifestChanged, @@ -288,16 +716,1060 @@ container }; }); +container + .bind>("Factory") + .toFactory]>(() => { + return (onProjectConfigChanged: Event) => { + return new DBTProjectLog(onProjectConfigChanged); + }; + }); + +// Bind services +container + .bind(AltimateAuthService) + .toDynamicValue((context) => { + return new AltimateAuthService(context.container.get("DBTConfiguration")); + }) + .inSingletonScope(); + +container + .bind(ConversationService) + .toDynamicValue((context) => { + return new ConversationService( + context.container.get(QueryManifestService), + context.container.get("DBTTerminal"), + context.container.get(AltimateRequest), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(DbtLineageService) + .toDynamicValue((context) => { + return new DbtLineageService( + context.container.get(AltimateRequest), + context.container.get(TelemetryService), + context.container.get("DBTTerminal"), + context.container.get(QueryManifestService), + ); + }) + .inSingletonScope(); + +container + .bind(DbtTestService) + .toDynamicValue((context) => { + return new DbtTestService( + context.container.get(DocGenService), + context.container.get(StreamingService), + context.container.get(AltimateRequest), + context.container.get(QueryManifestService), + context.container.get("DBTTerminal"), + context.container.get(TelemetryService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(DiagnosticsOutputChannel) + .toDynamicValue(() => { + return new DiagnosticsOutputChannel(); + }) + .inSingletonScope(); + +container + .bind(DocGenService) + .toDynamicValue((context) => { + return new DocGenService( + context.container.get(AltimateRequest), + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get(QueryManifestService), + context.container.get("DBTTerminal"), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(FileService) + .toDynamicValue(() => { + return new FileService(); + }) + .inSingletonScope(); + +container + .bind(QueryAnalysisService) + .toDynamicValue((context) => { + return new QueryAnalysisService( + context.container.get(DocGenService), + context.container.get(StreamingService), + context.container.get(AltimateRequest), + context.container.get(QueryManifestService), + context.container.get("DBTTerminal"), + context.container.get(FileService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(QueryManifestService) + .toDynamicValue((context) => { + return new QueryManifestService( + context.container.get(DBTProjectContainer), + context.container.get("DBTTerminal"), + context.container.get(SharedStateService), + context.container.get(ProjectQuickPick), + ); + }) + 
.inSingletonScope(); + +container + .bind(SharedStateService) + .toDynamicValue(() => { + return new SharedStateService(); + }) + .inSingletonScope(); + +container + .bind(ProjectQuickPick) + .toDynamicValue(() => { + return new ProjectQuickPick(); + }) + .inSingletonScope(); + +container + .bind(StreamingService) + .toDynamicValue((context) => { + return new StreamingService( + context.container.get(AltimateRequest), + context.container.get(SharedStateService), + ); + }) + .inSingletonScope(); + +container + .bind(UsersService) + .toDynamicValue((context) => { + return new UsersService( + context.container.get(DBTProjectContainer), + context.container.get("DBTTerminal"), + context.container.get(AltimateRequest), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(TelemetryService) + .toDynamicValue(() => { + return new TelemetryService(); + }) + .inSingletonScope(); + +container + .bind(ValidationProvider) + .toDynamicValue((context) => { + return new ValidationProvider( + context.container.get(AltimateRequest), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +// Bind manifest components +container + .bind(PythonEnvironment) + .toDynamicValue((context) => { + return new PythonEnvironment(context.container.get("DBTTerminal")); + }) + .inSingletonScope(); + +container + .bind(DBTProjectContainer) + .toDynamicValue((context) => { + return new DBTProjectContainer( + context.container.get(DBTClient), + context.container.get("Factory"), + context.container.get("DBTTerminal"), + context.container.get(AltimateDatapilot), + context.container.get(AltimateRequest), + ); + }) + .inSingletonScope(); + +// Bind MCP server tools +container + .bind(DbtPowerUserMcpServerTools) + .toDynamicValue((context) => { + return new DbtPowerUserMcpServerTools( + context.container.get(DBTProjectContainer), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +// Bind MCP server +container + .bind(DbtPowerUserMcpServer) + .toDynamicValue((context) => { + return new DbtPowerUserMcpServer( + context.container.get(DbtPowerUserMcpServerTools), + context.container.get("DBTTerminal"), + context.container.get(AltimateRequest), + context.container.get(SharedStateService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +// Bind dbt client +container + .bind(DBTClient) + .toDynamicValue((context) => { + return new DBTClient( + context.container.get(PythonEnvironment), + context.container.get("Factory"), + ); + }) + .inSingletonScope(); + +container + .bind(AltimateDatapilot) + .toDynamicValue((context) => { + return new AltimateDatapilot( + context.container.get(PythonEnvironment), + context.container.get(CommandProcessExecutionFactory), + context.container.get("DBTTerminal"), + context.container.get("DBTConfiguration"), + ); + }) + .inSingletonScope(); + +// Bind autocompletion providers +container + .bind(AutocompletionProviders) + .toDynamicValue((context) => { + return new AutocompletionProviders( + context.container.get(MacroAutocompletionProvider), + context.container.get(ModelAutocompletionProvider), + context.container.get(SourceAutocompletionProvider), + context.container.get(DocAutocompletionProvider), + context.container.get(UserCompletionProvider), + ); + }) + .inSingletonScope(); + +container + .bind(DocAutocompletionProvider) + .toDynamicValue((context) => { + return new DocAutocompletionProvider( + context.container.get(DBTProjectContainer), + 
context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(MacroAutocompletionProvider) + .toDynamicValue((context) => { + return new MacroAutocompletionProvider( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(ModelAutocompletionProvider) + .toDynamicValue((context) => { + return new ModelAutocompletionProvider( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(SourceAutocompletionProvider) + .toDynamicValue((context) => { + return new SourceAutocompletionProvider( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(UserCompletionProvider) + .toDynamicValue((context) => { + return new UserCompletionProvider(context.container.get(UsersService)); + }) + .inSingletonScope(); + +// Bind code lens providers +container + .bind(CodeLensProviders) + .toDynamicValue((context) => { + return new CodeLensProviders( + context.container.get(DBTProjectContainer), + context.container.get(SourceModelCreationCodeLensProvider), + context.container.get(VirtualSqlCodeLensProvider), + context.container.get(DocumentationCodeLensProvider), + context.container.get(CteCodeLensProvider), + ); + }) + .inSingletonScope(); + +container + .bind(CteCodeLensProvider) + .toDynamicValue((context) => { + return new CteCodeLensProvider( + context.container.get("DBTTerminal"), + context.container.get(AltimateRequest), + ); + }) + .inSingletonScope(); + +container + .bind(DocumentationCodeLensProvider) + .toDynamicValue(() => { + return new DocumentationCodeLensProvider(); + }) + .inSingletonScope(); + +container + .bind(SourceModelCreationCodeLensProvider) + .toDynamicValue(() => { + return new SourceModelCreationCodeLensProvider(); + }) + .inSingletonScope(); + +container + .bind(VirtualSqlCodeLensProvider) + .toDynamicValue((context) => { + return new VirtualSqlCodeLensProvider( + context.container.get(DBTProjectContainer), + context.container.get(QueryManifestService), + context.container.get("NotebookService"), + ); + }) + .inSingletonScope(); + +// Bind definition providers +container + .bind(DefinitionProviders) + .toDynamicValue((context) => { + return new DefinitionProviders( + context.container.get(ModelDefinitionProvider), + context.container.get(MacroDefinitionProvider), + context.container.get(SourceDefinitionProvider), + context.container.get(DocDefinitionProvider), + ); + }) + .inSingletonScope(); + +container + .bind(DocDefinitionProvider) + .toDynamicValue((context) => { + return new DocDefinitionProvider( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(MacroDefinitionProvider) + .toDynamicValue((context) => { + return new MacroDefinitionProvider( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(ModelDefinitionProvider) + .toDynamicValue((context) => { + return new ModelDefinitionProvider( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +container + .bind(SourceDefinitionProvider) + .toDynamicValue((context) => { + return new SourceDefinitionProvider( + context.container.get(DBTProjectContainer), + 
      context.container.get(TelemetryService),
+    );
+  })
+  .inSingletonScope();
+
+// Bind hover providers
+container
+  .bind(HoverProviders)
+  .toDynamicValue((context) => {
+    return new HoverProviders(
+      context.container.get(ModelHoverProvider),
+      context.container.get(SourceHoverProvider),
+      context.container.get(MacroHoverProvider),
+      context.container.get(DepthDecorationProvider),
+    );
+  })
+  .inSingletonScope();
+
+container
+  .bind(DepthDecorationProvider)
+  .toDynamicValue((context) => {
+    return new DepthDecorationProvider(
+      context.container.get(DBTProjectContainer),
+    );
+  })
+  .inSingletonScope();
+
+container
+  .bind(MacroHoverProvider)
+  .toDynamicValue((context) => {
+    return new MacroHoverProvider(
+      context.container.get(TelemetryService),
+      context.container.get("DBTTerminal"),
+      context.container.get(QueryManifestService),
+    );
+  })
+  .inSingletonScope();
+
+container
+  .bind(ModelHoverProvider)
+  .toDynamicValue((context) => {
+    return new ModelHoverProvider(
+      context.container.get(DBTProjectContainer),
+      context.container.get(TelemetryService),
+      context.container.get("DBTTerminal"),
+    );
+  })
+  .inSingletonScope();
+
+container
+  .bind(SourceHoverProvider)
+  .toDynamicValue((context) => {
+    return new SourceHoverProvider(
+      context.container.get(DBTProjectContainer),
+      context.container.get(TelemetryService),
+    );
+  })
+  .inSingletonScope();
+
+// Bind notebook-related services
+
+container
+  .bind("NotebookFileSystemProvider")
+  .toDynamicValue((context) => {
+    return new LibNamespace.NotebookFileSystemProvider(
+      context.container.get("DBTTerminal"),
+      context.container.get(AltimateRequest),
+    );
+  })
+  .inSingletonScope();
+
 container
   .bind("Factory")
   .toFactory((context: interfaces.Context) => {
     return (path: string) => {
       const { container } = context;
-      return new NotebookKernelClient(
+      return new LibNamespace.NotebookKernelClient(
         path,
         container.get(DBTCommandExecutionInfrastructure),
-        container.get(NotebookDependencies),
-        container.get(DBTTerminal),
+        container.get("NotebookDependencies"),
+        container.get("DBTTerminal"),
       );
     };
   });
+container
+  .bind("NotebookDependencies")
+  .toDynamicValue((context) => {
+    return new LibNamespace.NotebookDependencies(
+      context.container.get("DBTTerminal"),
+      context.container.get(TelemetryService),
+      context.container.get(CommandProcessExecutionFactory),
+      context.container.get(PythonEnvironment),
+    );
+  })
+  .inSingletonScope();
+
+container
+  .bind("ClientMapper")
+  .toDynamicValue((context) => {
+    return new LibNamespace.ClientMapper(
+      context.container.get(DBTCommandExecutionInfrastructure),
+      context.container.get("NotebookDependencies"),
+      context.container.get("DBTTerminal"),
+    );
+  })
+  .inSingletonScope();
+
+container
+  .bind("DatapilotNotebookSerializer")
+  .toDynamicValue(() => {
+    return new LibNamespace.DatapilotNotebookSerializer();
+  })
+  .inSingletonScope();
+
+container
+  .bind("DatapilotNotebookController")
+  .toDynamicValue((context) => {
+    return new LibNamespace.DatapilotNotebookController(
+      context.container.get("ClientMapper"),
+      context.container.get(QueryManifestService),
+      context.container.get(TelemetryService),
+      context.container.get("DBTTerminal"),
+      context.container.get("NotebookDependencies"),
+      context.container.get(AltimateRequest),
+    );
+  })
+  .inSingletonScope();
+
+container
+  .bind("NotebookService")
+  .toDynamicValue((context) => {
+    return new LibNamespace.NotebookService(
+      context.container.get("DatapilotNotebookController"),
+    );
+  })
+  .inSingletonScope();
+
+container
+
.bind("NotebookProviders") + .toDynamicValue((context) => { + return new LibNamespace.NotebookProviders( + context.container.get("DatapilotNotebookSerializer"), + context.container.get("DatapilotNotebookController"), + context.container.get("NotebookFileSystemProvider"), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +// Bind test components +container + .bind(MissingSchemaTest) + .toDynamicValue(() => { + return new MissingSchemaTest(); + }) + .inSingletonScope(); + +container + .bind(UndocumentedModelColumnTest) + .toDynamicValue(() => { + return new UndocumentedModelColumnTest(); + }) + .inSingletonScope(); + +container + .bind(UnmaterializedModelTest) + .toDynamicValue(() => { + return new UnmaterializedModelTest(); + }) + .inSingletonScope(); + +container + .bind(StaleModelColumnTest) + .toDynamicValue(() => { + return new StaleModelColumnTest(); + }) + .inSingletonScope(); + +// Bind additional webview components +container + .bind(DbtDocsView) + .toDynamicValue((context) => { + return new DbtDocsView( + context.container.get(DBTProjectContainer), + context.container.get(AltimateRequest), + context.container.get(TelemetryService), + context.container.get(SharedStateService), + context.container.get("DBTTerminal"), + context.container.get(QueryManifestService), + context.container.get(UsersService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(SqlPreviewContentProvider) + .toDynamicValue((context) => { + return new SqlPreviewContentProvider( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(DbtDocumentFormattingEditProvider) + .toDynamicValue((context) => { + return new DbtDocumentFormattingEditProvider( + context.container.get(CommandProcessExecutionFactory), + context.container.get(TelemetryService), + context.container.get(PythonEnvironment), + ); + }) + .inSingletonScope(); + +// Bind status bar components +container + .bind(VersionStatusBar) + .toDynamicValue((context) => { + return new VersionStatusBar(context.container.get(DBTProjectContainer)); + }) + .inSingletonScope(); + +container + .bind(DeferToProductionStatusBar) + .toDynamicValue((context) => { + return new DeferToProductionStatusBar( + context.container.get(DBTProjectContainer), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +container + .bind(TargetStatusBar) + .toDynamicValue((context) => { + return new TargetStatusBar( + context.container.get(DBTProjectContainer), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +// Bind quick pick components +container + .bind(DbtPowerUserControlCenterAction) + .toDynamicValue(() => { + return new DbtPowerUserControlCenterAction(); + }) + .inSingletonScope(); + +container + .bind(DbtSQLAction) + .toDynamicValue((context) => { + return new DbtSQLAction(context.container.get(DBTProjectContainer)); + }) + .inSingletonScope(); + +// Bind individual command components that are required by VSCodeCommands +container + .bind(RunModel) + .toDynamicValue((context) => { + return new RunModel(context.container.get(DBTProjectContainer)); + }) + .inSingletonScope(); + +container + .bind(SqlToModel) + .toDynamicValue((context) => { + return new SqlToModel( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get(AltimateRequest), + context.container.get("DBTTerminal"), + context.container.get(AltimateAuthService), 
+ ); + }) + .inSingletonScope(); + +container + .bind(ValidateSql) + .toDynamicValue((context) => { + return new ValidateSql( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get(AltimateRequest), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +container + .bind(AltimateScan) + .toDynamicValue((context) => { + return new AltimateScan( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get(AltimateRequest), + context.container.get(MissingSchemaTest), + context.container.get(UndocumentedModelColumnTest), + context.container.get(UnmaterializedModelTest), + context.container.get(StaleModelColumnTest), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +container + .bind(WalkthroughCommands) + .toDynamicValue((context) => { + return new WalkthroughCommands( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get(CommandProcessExecutionFactory), + context.container.get(PythonEnvironment), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +container + .bind(BigQueryCostEstimate) + .toDynamicValue((context) => { + return new BigQueryCostEstimate( + context.container.get(DBTProjectContainer), + context.container.get("DBTTerminal"), + context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(ConversationProvider) + .toDynamicValue((context) => { + return new ConversationProvider( + context.container.get(ConversationService), + context.container.get(UsersService), + context.container.get("DBTTerminal"), + context.container.get(SharedStateService), + context.container.get(QueryManifestService), + context.container.get(TelemetryService), + ); + }) + .inSingletonScope(); + +container + .bind(SQLLineagePanel) + .toDynamicValue((context) => { + return new SQLLineagePanel( + context.container.get(DBTProjectContainer), + context.container.get(AltimateRequest), + context.container.get(TelemetryService), + context.container.get("DBTTerminal"), + context.container.get(QueryManifestService), + context.container.get(SharedStateService), + context.container.get(UsersService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(VSCodeCommands) + .toDynamicValue((context) => { + return new VSCodeCommands( + context.container.get(DBTProjectContainer), + context.container.get(RunModel), + context.container.get(SqlToModel), + context.container.get(ValidateSql), + context.container.get(AltimateScan), + context.container.get(WalkthroughCommands), + context.container.get(BigQueryCostEstimate), + context.container.get("DBTTerminal"), + context.container.get(DiagnosticsOutputChannel), + context.container.get(SharedStateService), + context.container.get(ConversationProvider), + context.container.get(PythonEnvironment), + context.container.get(DBTClient), + context.container.get(SQLLineagePanel), + context.container.get(QueryManifestService), + context.container.get(AltimateRequest), + context.container.get("DatapilotNotebookController"), + ); + }) + .inSingletonScope(); + +// Bind webview panel components +container + .bind(QueryResultPanel) + .toDynamicValue((context) => { + return new QueryResultPanel( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get(AltimateRequest), + context.container.get(SharedStateService), + 
context.container.get("DBTTerminal"), + context.container.get(QueryManifestService), + context.container.get(UsersService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(DocsEditViewPanel) + .toDynamicValue((context) => { + return new DocsEditViewPanel( + context.container.get(DBTProjectContainer), + context.container.get(AltimateRequest), + context.container.get(TelemetryService), + context.container.get(NewDocsGenPanel), + context.container.get(DocGenService), + context.container.get(DbtTestService), + context.container.get("DBTTerminal"), + context.container.get(DbtLineageService), + ); + }) + .inSingletonScope(); + +container + .bind(LineagePanel) + .toDynamicValue((context) => { + return new LineagePanel( + context.container.get(NewLineagePanel), + context.container.get(ModelGraphViewPanel), + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +container + .bind(DataPilotPanel) + .toDynamicValue((context) => { + return new DataPilotPanel( + context.container.get(DBTProjectContainer), + context.container.get(TelemetryService), + context.container.get(AltimateRequest), + context.container.get(DocGenService), + context.container.get(SharedStateService), + context.container.get(QueryAnalysisService), + context.container.get(QueryManifestService), + context.container.get("DBTTerminal"), + context.container.get(DbtTestService), + context.container.get(FileService), + context.container.get(UsersService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(InsightsPanel) + .toDynamicValue((context) => { + return new InsightsPanel( + context.container.get(DBTProjectContainer), + context.container.get(AltimateRequest), + context.container.get(DbtIntegrationClient), + context.container.get(TelemetryService), + context.container.get(SharedStateService), + context.container.get("DBTTerminal"), + context.container.get(QueryManifestService), + context.container.get(ValidationProvider), + context.container.get(UsersService), + context.container.get("NotebookFileSystemProvider"), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(NewDocsGenPanel) + .toDynamicValue((context) => { + return new NewDocsGenPanel( + context.container.get(DBTProjectContainer), + context.container.get(AltimateRequest), + context.container.get(TelemetryService), + context.container.get(DocGenService), + context.container.get(SharedStateService), + context.container.get(QueryManifestService), + context.container.get("DBTTerminal"), + context.container.get(DbtTestService), + context.container.get(UsersService), + context.container.get(DbtDocsView), + context.container.get(ConversationProvider), + context.container.get(ConversationService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + .bind(NewLineagePanel) + .toDynamicValue((context) => { + return new NewLineagePanel( + context.container.get(DBTProjectContainer), + context.container.get(AltimateRequest), + context.container.get(TelemetryService), + context.container.get("DBTTerminal"), + context.container.get(DbtLineageService), + context.container.get(SharedStateService), + context.container.get(QueryManifestService), + context.container.get(UsersService), + context.container.get(AltimateAuthService), + ); + }) + .inSingletonScope(); + +container + 
.bind(ModelGraphViewPanel) + .toDynamicValue((context) => { + return new ModelGraphViewPanel( + context.container.get(DBTProjectContainer), + context.container.get("DBTTerminal"), + ); + }) + .inSingletonScope(); + +// Bind WebviewViewProviders +container + .bind(WebviewViewProviders) + .toDynamicValue((context) => { + return new WebviewViewProviders( + context.container.get(QueryResultPanel), + context.container.get(DocsEditViewPanel), + context.container.get(LineagePanel), + context.container.get(DataPilotPanel), + context.container.get(InsightsPanel), + ); + }) + .inSingletonScope(); + +// Bind treeview components +container + .bind(ChildrenModelTreeview) + .toDynamicValue((context) => { + return new ChildrenModelTreeview( + context.container.get(DBTProjectContainer), + ); + }) + .inSingletonScope(); + +container + .bind(ParentModelTreeview) + .toDynamicValue((context) => { + return new ParentModelTreeview(context.container.get(DBTProjectContainer)); + }) + .inSingletonScope(); + +container + .bind(ModelTestTreeview) + .toDynamicValue((context) => { + return new ModelTestTreeview(context.container.get(DBTProjectContainer)); + }) + .inSingletonScope(); + +container + .bind(DocumentationTreeview) + .toDynamicValue((context) => { + return new DocumentationTreeview( + context.container.get(DBTProjectContainer), + ); + }) + .inSingletonScope(); + +container + .bind(IconActionsTreeview) + .toDynamicValue(() => { + return new IconActionsTreeview(); + }) + .inSingletonScope(); + +// Bind TreeviewProviders +container + .bind(TreeviewProviders) + .toDynamicValue((context) => { + return new TreeviewProviders( + context.container.get(ChildrenModelTreeview), + context.container.get(ParentModelTreeview), + context.container.get(ModelTestTreeview), + context.container.get(DocumentationTreeview), + context.container.get(IconActionsTreeview), + ); + }) + .inSingletonScope(); + +// Bind ContentProviders +container + .bind(ContentProviders) + .toDynamicValue((context) => { + return new ContentProviders( + context.container.get(SqlPreviewContentProvider), + ); + }) + .inSingletonScope(); + +// Bind DocumentFormattingEditProviders +container + .bind(DocumentFormattingEditProviders) + .toDynamicValue((context) => { + return new DocumentFormattingEditProviders( + context.container.get(DbtDocumentFormattingEditProvider), + ); + }) + .inSingletonScope(); + +// Bind StatusBars +container + .bind(StatusBars) + .toDynamicValue((context) => { + return new StatusBars( + context.container.get(VersionStatusBar), + context.container.get(DeferToProductionStatusBar), + context.container.get(TargetStatusBar), + ); + }) + .inSingletonScope(); + +// Bind DbtPowerUserActionsCenter +container + .bind(DbtPowerUserActionsCenter) + .toDynamicValue((context) => { + return new DbtPowerUserActionsCenter( + context.container.get(DbtPowerUserControlCenterAction), + context.container.get(ProjectQuickPick), + context.container.get(DBTProjectContainer), + context.container.get(DbtSQLAction), + ); + }) + .inSingletonScope(); + +// Bind CommentProviders +container + .bind(CommentProviders) + .toDynamicValue((context) => { + return new CommentProviders(context.container.get(ConversationProvider)); + }) + .inSingletonScope(); + +// Finally, bind the main DBTPowerUserExtension +container + .bind(DBTPowerUserExtension) + .toDynamicValue((context) => { + return new DBTPowerUserExtension( + context.container.get(DBTProjectContainer), + context.container.get(WebviewViewProviders), + context.container.get(AutocompletionProviders), + 
context.container.get(DefinitionProviders), + context.container.get(VSCodeCommands), + context.container.get(TreeviewProviders), + context.container.get(ContentProviders), + context.container.get(CodeLensProviders), + context.container.get(DocumentFormattingEditProviders), + context.container.get(StatusBars), + context.container.get(DbtPowerUserActionsCenter), + context.container.get(TelemetryService), + context.container.get(HoverProviders), + context.container.get(ValidationProvider), + context.container.get(CommentProviders), + context.container.get("NotebookProviders"), + context.container.get(DbtPowerUserMcpServer), + ); + }) + .inSingletonScope(); diff --git a/src/lib/index.d.ts b/src/lib/index.d.ts index 434d5cd5b..90ba45ca0 100644 --- a/src/lib/index.d.ts +++ b/src/lib/index.d.ts @@ -1,118 +1,34 @@ -import { AltimateRequest } from "../dependencies.d.ts"; -import { CancellationToken } from "vscode"; -import { CommandProcessExecutionFactory } from "../../dependencies.d.ts"; -import { DBTCommandExecutionInfrastructure } from "../dependencies.d.ts"; -import { DBTCommandExecutionInfrastructure as DBTCommandExecutionInfrastructure_2 } from "../../dependencies.d.ts"; -import { DBTProject } from "../../dependencies.d.ts"; -import { DBTTerminal } from "../../dependencies.d.ts"; -import { DBTTerminal as DBTTerminal_2 } from "../dependencies.d.ts"; -import { Disposable as Disposable_2 } from "vscode"; -import { Event as Event_2 } from "vscode"; -import { ExecuteSQLResult } from "../dependencies.d.ts"; -import { FileChangeEvent } from "vscode"; -import { FileStat } from "vscode"; -import { FileSystemProvider } from "vscode"; -import { FileType } from "vscode"; +import { + AltimateRequest, + CommandProcessExecutionFactory, + DBTCommandExecutionInfrastructure, + DBTProject, + DBTTerminal, + PythonEnvironment, + QueryExecutionResult, + QueryManifestService, + TelemetryService, +} from "@extension"; import { KernelConnection } from "@jupyterlab/services"; -import { NotebookCell } from "vscode"; -import { NotebookCellKind } from "vscode"; -import { NotebookCellOutput } from "vscode"; -import { NotebookData } from "vscode"; -import { NotebookSerializer } from "vscode"; -import { PythonEnvironment } from "../../dependencies.d.ts"; -import { QueryManifestService } from "../dependencies.d.ts"; -import { TelemetryService } from "../../dependencies.d.ts"; -import { TelemetryService as TelemetryService_2 } from "../dependencies.d.ts"; -import { Uri } from "vscode"; - -declare class ClientMapper { - private executionInfrastructure; - private notebookDependencies; - private dbtTerminal; - private clientMap; - constructor( - executionInfrastructure: DBTCommandExecutionInfrastructure, - notebookDependencies: NotebookDependencies, - dbtTerminal: DBTTerminal_2, - ); - initializeNotebookClient(notebookUri: Uri): Promise; - getNotebookClient(notebookUri: Uri): Promise; -} - -declare interface ColumnConfig { - name: string; - tests: string[]; - [key: string]: any; -} - -declare interface ConnectionSettings { - control_port: number; - hb_port: number; - iopub_port: number; - ip: string; - key: string; - kernel_name: string; - shell_port: number; - signature_scheme: string; - stdin_port: number; - transport: string; -} - -export declare const CustomNotebooks: { - notebooks: PreconfiguredNotebookItem[]; -}; - -export declare class DatapilotNotebookController implements Disposable_2 { - private clientMapper; - private queryManifestService; - private telemetry; - private dbtTerminal; - private notebookDependencies; - 
private altimate; - private readonly id; - private readonly label; - private _onNotebookCellEvent; - readonly onNotebookCellChangeEvent: Event_2; - private readonly disposables; - private associatedNotebooks; - private executionOrder; - private readonly controller; - constructor( - clientMapper: ClientMapper, - queryManifestService: QueryManifestService, - telemetry: TelemetryService_2, - dbtTerminal: DBTTerminal_2, - notebookDependencies: NotebookDependencies, - altimate: AltimateRequest, - ); - private getNotebookByTemplate; - modelTestSuggestions(args: any): Promise; - generateDbtSourceYaml(args: any): Promise; - generateDbtDbtModelSql(args: any): Promise; - generateDbtDbtModelYaml(args: any): Promise; - generateDbtDbtModelCTE(args: any): Promise; - extractExposuresFromMetabase(args: any): Promise; - extractExposuresFromTableau(args: any): Promise; - private getFileName; - createNotebook(args: OpenNotebookRequest | undefined): Promise; - private sendMessageToPreloadScript; - private getRandomString; - private genUniqueId; - private updateCellId; - private onNotebookClose; - private onDidChangeSelectedNotebooks; - private onNotebookOpen; - private waitForControllerAssociation; - private isControllerAssociatedWithNotebook; - dispose(): void; - private _executeAll; - private filterIPyWidgets; - private updateContextVariablesInKernel; - private _doExecution; -} +import * as vscode from "vscode"; +import { + CancellationToken, + Disposable, + Event, + FileChangeEvent, + FileStat, + FileSystemProvider, + FileType, + NotebookCell, + NotebookCellKind, + NotebookCellOutput, + NotebookData, + NotebookSerializer, + Uri, +} from "vscode"; declare class DatapilotNotebookSerializer - implements NotebookSerializer, Disposable_2 + implements NotebookSerializer, Disposable { dispose(): void; deserializeNotebook( @@ -125,65 +41,61 @@ declare class DatapilotNotebookSerializer ): Promise; } -declare interface DBColumn { - column: string; - dtype: string; +interface NotebookItem { + id: number; + name: string; + data: NotebookSchema; + description: string; + created_on: string; + updated_on: string; + tags: { + id: number; + tag: string; + }[]; + privacy: boolean; } - -export declare interface DbtConfig { - [key: string]: Model[]; +interface NotebookSchema { + cells: NotebookCellSchema[]; + metadata?: Record; } - -export declare const getTestSuggestions: ({ - tableRelation, - sample, - limit, - resourceType, - columnConfig, - excludeTypes, - excludeCols, - tests, - uniquenessCompositeKeyLength, - acceptedValuesMaxCardinality, - rangeStddevs, - stringLengthStddevs, - recencyStddevs, - dbtConfig, - returnObject, - columnsInRelation, - adapter, - queryFn, -}: Props) => Promise; - -export declare interface IPyWidgetMessage { - type: string; - payload: any; +interface NotebookCellSchema { + source: string[]; + cell_type: NotebookCellKind; + languageId: string; + metadata?: Record; + outputs?: NotebookCellOutput[]; } - -export declare interface Model { +interface PreconfiguredNotebookItem { name: string; - columns: ColumnConfig[]; - tests?: any[]; + description: string; + created_at: string; + updated_at: string; + id: string; + tags: string[]; + data: NotebookSchema; } - -export declare interface NotebookCellEvent { +interface NotebookCellEvent { cellId: string; notebook: string; - result?: any; + result?: Record; event: "add" | "update" | "delete"; fragment?: string; languageId: string; } - -export declare interface NotebookCellSchema { - source: string[]; - cell_type: NotebookCellKind; - languageId: string; - 
metadata?: Record; - outputs?: NotebookCellOutput[]; +interface OpenNotebookRequest { + notebookId?: string; + template?: string; + context?: Record; + notebookSchema?: NotebookSchema; +} +interface NotebookDependency { + type: "dbt" | "python"; + package: string; + name?: string; + version?: string; } -export declare class NotebookDependencies { +declare class NotebookDependencies { private readonly dbtTerminal; private readonly telemetry; private commandProcessExecutionFactory; @@ -207,77 +119,42 @@ export declare class NotebookDependencies { private notebookDependenciesAreInstalled; } -export declare interface NotebookDependency { - type: "dbt" | "python"; - package: string; - name?: string; - version?: string; +interface ConnectionSettings { + control_port: number; + hb_port: number; + iopub_port: number; + ip: string; + key: string; + kernel_name: string; + shell_port: number; + signature_scheme: string; + stdin_port: number; + transport: string; } - -export declare class NotebookFileSystemProvider implements FileSystemProvider { - private dbtTerminal; - private altimate; - private _emitter; - readonly onDidChangeFile: Event_2; - private notebookDataMap; - constructor(dbtTerminal: DBTTerminal_2, altimate: AltimateRequest); - watch( - _uri: Uri, - _options: { - recursive: boolean; - excludes: string[]; - }, - ): Disposable_2; - stat(_uri: Uri): FileStat; - readDirectory(_uri: Uri): [string, FileType][]; - createDirectory(_uri: Uri): void; - private getNotebookData; - readFile(uri: Uri): Promise; - writeFile( - uri: Uri, - content: Uint8Array, - _options: { - create: boolean; - overwrite: boolean; - }, - ): Promise; - delete( - uri: Uri, - _options: { - recursive: boolean; - }, - ): void; - rename( - oldUri: Uri, - newUri: Uri, - _options: { - overwrite: boolean; - }, - ): void; - private getFileNameFromUri; - private customSave; - private saveNotebook; +interface RawKernelType { + realKernel: KernelConnection; + socket: WebSocket; + kernelProcess: { + connection: ConnectionSettings; + pid: number; + }; } -export declare interface NotebookItem { - id: number; - name: string; - data: NotebookSchema; - description: string; - created_on: string; - updated_on: string; - tags: { - id: number; - tag: string; - }[]; - privacy: boolean; +interface WidgetPayload { + [key: string]: unknown; } - -export declare class NotebookKernelClient implements Disposable_2 { +interface IPyWidgetMessage { + type: string; + payload: WidgetPayload; +} +interface NotebookContext { + [key: string]: unknown; +} +declare class NotebookKernelClient implements Disposable { private executionInfrastructure; private notebookDependencies; private dbtTerminal; - get postMessage(): Event_2; + get postMessage(): Event; private _postMessageEmitter; private disposables; private lastUsedStreamOutput?; @@ -295,19 +172,19 @@ export declare class NotebookKernelClient implements Disposable_2 { private versions?; constructor( notebookPath: string, - executionInfrastructure: DBTCommandExecutionInfrastructure_2, + executionInfrastructure: DBTCommandExecutionInfrastructure, notebookDependencies: NotebookDependencies, dbtTerminal: DBTTerminal, ); - isInitialized(): Promise; + isInitialized(): Promise; dispose(): Promise; get jupyterPackagesVersions(): Record | undefined; private getDependenciesVersion; getKernel(): Promise; private initializeNotebookKernel; onKernelSocketResponse(payload: { id: string }): void; - storeContext(context: any): Promise; - storeDataInKernel(cellId: string, data: any): Promise; + storeContext(context: 
NotebookContext): Promise; + storeDataInKernel(cellId: string, data: unknown): Promise; registerCommTarget(payload: string): Promise; getPythonCodeByType(type: string, cellId: string): Promise; executePython( @@ -329,28 +206,134 @@ export declare class NotebookKernelClient implements Disposable_2 { private handleError; } -export declare class NotebookProviders implements Disposable_2 { - private notebookProvider; - private notebookController; - private notebookFileSystemProvider; +declare class ClientMapper { + private executionInfrastructure; + private notebookDependencies; private dbtTerminal; - private disposables; + private clientMap; constructor( - notebookProvider: DatapilotNotebookSerializer, - notebookController: DatapilotNotebookController, - notebookFileSystemProvider: NotebookFileSystemProvider, - dbtTerminal: DBTTerminal_2, + executionInfrastructure: DBTCommandExecutionInfrastructure, + notebookDependencies: NotebookDependencies, + dbtTerminal: DBTTerminal, ); - private bindNotebookActions; + initializeNotebookClient(notebookUri: Uri): Promise; + getNotebookClient(notebookUri: Uri): Promise; +} + +interface ModelTestArgs { + model: string; + tests?: string[]; +} +interface DbtSourceYamlArgs { + source: string; + schema?: string; + database?: string; +} +interface DbtModelArgs { + model: string; + schema?: string; + database?: string; + description?: string; +} +interface ExposureArgs { + connection: string; + project?: string; +} +declare class DatapilotNotebookController implements Disposable { + private clientMapper; + private queryManifestService; + private telemetry; + private dbtTerminal; + private notebookDependencies; + private altimate; + private readonly id; + private readonly label; + private _onNotebookCellEvent; + readonly onNotebookCellChangeEvent: vscode.Event; + private readonly disposables; + private associatedNotebooks; + private executionOrder; + private readonly controller; + constructor( + clientMapper: ClientMapper, + queryManifestService: QueryManifestService, + telemetry: TelemetryService, + dbtTerminal: DBTTerminal, + notebookDependencies: NotebookDependencies, + altimate: AltimateRequest, + ); + private getNotebookByTemplate; + modelTestSuggestions(args: ModelTestArgs): Promise; + generateDbtSourceYaml(args: DbtSourceYamlArgs): Promise; + generateDbtDbtModelSql(args: DbtModelArgs): Promise; + generateDbtDbtModelYaml(args: DbtModelArgs): Promise; + generateDbtDbtModelCTE(args: DbtModelArgs): Promise; + extractExposuresFromMetabase(args: ExposureArgs): Promise; + extractExposuresFromTableau(args: ExposureArgs): Promise; + private getFileName; + createNotebook(args: OpenNotebookRequest | undefined): Promise; + private sendMessageToPreloadScript; + private getRandomString; + private genUniqueId; + private updateCellId; + private onNotebookClose; + private onDidChangeSelectedNotebooks; + private onNotebookOpen; + private waitForControllerAssociation; + private isControllerAssociatedWithNotebook; dispose(): void; + private _executeAll; + private filterIPyWidgets; + private updateContextVariablesInKernel; + private _doExecution; } -export declare interface NotebookSchema { - cells: NotebookCellSchema[]; - metadata?: Record; +declare class NotebookFileSystemProvider implements FileSystemProvider { + private dbtTerminal; + private altimate; + private _emitter; + readonly onDidChangeFile: Event; + private notebookDataMap; + constructor(dbtTerminal: DBTTerminal, altimate: AltimateRequest); + watch( + _uri: Uri, + _options: { + recursive: boolean; + excludes: 
string[]; + }, + ): Disposable; + stat(_uri: Uri): FileStat; + readDirectory(_uri: Uri): [string, FileType][]; + createDirectory(_uri: Uri): void; + private getNotebookData; + readFile(uri: Uri): Promise; + writeFile( + uri: Uri, + content: Uint8Array, + _options: { + create: boolean; + overwrite: boolean; + }, + ): Promise; + delete( + uri: Uri, + _options: { + recursive: boolean; + }, + ): void; + rename( + oldUri: Uri, + newUri: Uri, + _options: { + overwrite: boolean; + }, + ): void; + private getFileNameFromUri; + private customSave; + private saveNotebook; } -export declare class NotebookService implements Disposable_2 { +declare class NotebookService implements Disposable { private notebookKernel; private disposables; private cellByNotebookAutocompleteMap; @@ -367,29 +350,59 @@ export declare class NotebookService implements Disposable_2 { private onNotebookCellChanged; } -export declare interface OpenNotebookRequest { - notebookId?: string; - template?: string; - context?: Record; - notebookSchema?: NotebookSchema; +declare const CustomNotebooks: { + notebooks: PreconfiguredNotebookItem[]; +}; + +declare class NotebookProviders implements Disposable { + private notebookProvider; + private notebookController; + private notebookFileSystemProvider; + private dbtTerminal; + private disposables; + constructor( + notebookProvider: DatapilotNotebookSerializer, + notebookController: DatapilotNotebookController, + notebookFileSystemProvider: NotebookFileSystemProvider, + dbtTerminal: DBTTerminal, + ); + private bindNotebookActions; + dispose(): void; } -export declare interface PreconfiguredNotebookItem { +interface DBColumn { + column: string; + dtype: string; +} +interface ColumnConfig$1 { name: string; - description: string; - created_at: string; - updated_at: string; - id: string; - tags: string[]; - data: NotebookSchema; + tests: string[]; + description?: string; + [key: string]: string | string[] | undefined; } +interface Model { + name: string; + columns: ColumnConfig$1[]; + tests?: Array<{ + [key: string]: unknown; + }>; +} +interface DbtConfig { + [key: string]: Model[]; +} +type QueryFn = (query: string) => Promise; -declare interface Props { +interface ColumnConfig { + tests?: string[]; + description?: string; + [key: string]: unknown; +} +interface Props { tableRelation: string; sample?: boolean; limit?: number; resourceType?: string; - columnConfig?: Record; + columnConfig?: Record; excludeTypes?: string[]; excludeCols?: string[]; tests?: ( @@ -404,40 +417,52 @@ declare interface Props { rangeStddevs?: number; stringLengthStddevs?: number; recencyStddevs?: number; - dbtConfig?: Record; + dbtConfig?: DbtConfig; returnObject?: boolean; columnsInRelation: DBColumn[]; adapter: string; queryFn: QueryFn; } +declare const getTestSuggestions: ({ + tableRelation, + sample, + limit, + resourceType, + columnConfig, + excludeTypes, + excludeCols, + tests, + uniquenessCompositeKeyLength, + acceptedValuesMaxCardinality, + rangeStddevs, + stringLengthStddevs, + recencyStddevs, + dbtConfig, + returnObject, + columnsInRelation, + adapter, + queryFn, +}: Props) => Promise; -declare type QueryFn = (query: string) => Promise; - -declare interface RawKernelType { - realKernel: KernelConnection; - socket: any; - kernelProcess: { - connection: ConnectionSettings; - pid: number; - }; -} - -export {}; - -export declare namespace Identifiers { - const GeneratedThemeName = "ipython-theme"; - const MatplotLibDefaultParams = "_VSCode_defaultMatplotlib_Params"; - const MatplotLibFigureFormats = 
"_VSCode_matplotLib_FigureFormats"; - const DefaultCodeCellMarker = "# %%"; - const DefaultCommTarget = "jupyter.widget"; - const ALL_VARIABLES = "ALL_VARIABLES"; - const KERNEL_VARIABLES = "KERNEL_VARIABLES"; - const DEBUGGER_VARIABLES = "DEBUGGER_VARIABLES"; - const PYTHON_VARIABLES_REQUESTER = "PYTHON_VARIABLES_REQUESTER"; - const MULTIPLEXING_DEBUGSERVICE = "MULTIPLEXING_DEBUGSERVICE"; - const RUN_BY_LINE_DEBUGSERVICE = "RUN_BY_LINE_DEBUGSERVICE"; - const REMOTE_URI = "https://remote/"; - const REMOTE_URI_ID_PARAM = "id"; - const REMOTE_URI_HANDLE_PARAM = "uriHandle"; - const REMOTE_URI_EXTENSION_ID_PARAM = "extensionId"; -} +export { + ClientMapper, + CustomNotebooks, + DatapilotNotebookController, + DatapilotNotebookSerializer, + type DbtConfig, + type IPyWidgetMessage, + type Model, + type NotebookCellEvent, + type NotebookCellSchema, + NotebookDependencies, + type NotebookDependency, + NotebookFileSystemProvider, + type NotebookItem, + NotebookKernelClient, + NotebookProviders, + type NotebookSchema, + NotebookService, + type OpenNotebookRequest, + type PreconfiguredNotebookItem, + getTestSuggestions, +}; diff --git a/src/lib/index.js b/src/lib/index.js index f182d345c..bee2471d0 100644 --- a/src/lib/index.js +++ b/src/lib/index.js @@ -1,39 +1,47 @@ "use strict"; Object.defineProperty(exports, Symbol.toStringTag, { value: "Module" }); -const l = require("vscode"), +const u = require("vscode"), p = require("@extension"), ie = require("python-bridge"), - ue = require("fs"), - de = require("@nteract/messaging/lib/wire-protocol"); -function pe(o) { - const e = Object.create(null, { [Symbol.toStringTag]: { value: "Module" } }); - if (o) { - for (const t in o) - if (t !== "default") { - const n = Object.getOwnPropertyDescriptor(o, t); - Object.defineProperty( - e, - t, - n.get ? n : { enumerable: !0, get: () => o[t] }, - ); - } - } - return (e.default = o), Object.freeze(e); + pe = require("fs"), + ne = require("@jupyterlab/services"); +function L(o, e, t, n) { + var r = arguments.length, + s = + r < 3 ? e : n === null ? (n = Object.getOwnPropertyDescriptor(e, t)) : n, + a; + if (typeof Reflect == "object" && typeof Reflect.decorate == "function") + s = Reflect.decorate(o, e, t, n); + else + for (var i = o.length - 1; i >= 0; i--) + (a = o[i]) && (s = (r < 3 ? a(s) : r > 3 ? a(e, t, s) : a(e, t)) || s); + return r > 3 && s && Object.defineProperty(e, t, s), s; +} +function $(o, e) { + return function (t, n) { + e(t, n, o); + }; +} +function me(o, e, t) { + return Object.defineProperty(o, "name", { configurable: !0, value: e }); } -const te = pe(de), - he = (o) => ("getCells" in o ? o.getCells() : o.cells), - me = (o) => - o instanceof l.NotebookCellData ? o.value : o.document.getText(), - fe = (o) => - o instanceof l.NotebookCellData ? o.languageId : o.document.languageId, - X = (o, e, t) => { +function q(o, e) { + if (typeof Reflect == "object" && typeof Reflect.metadata == "function") + return Reflect.metadata(o, e); +} +const he = (o) => ("getCells" in o ? o.getCells() : o.cells), + be = (o) => + o instanceof u.NotebookCellData ? o.value : o.document.getText(), + ge = (o) => + o instanceof u.NotebookCellData ? o.languageId : o.document.languageId, + Z = (o, e, t) => { var r; const n = []; for (const s of he(o)) n.push({ cell_type: s.kind, - source: me(s).split(/\r?\n/g), - languageId: fe(s), + source: be(s).split(/\r?\n/g), + languageId: ge(s), metadata: s.metadata, outputs: t ? 
s.outputs : void 0, }); @@ -49,26 +57,14 @@ const te = pe(de), }, }; }, - A = () => Math.random().toString(36).substr(2, 9); -function be() { + x = () => Math.random().toString(36).substr(2, 9); +function fe() { const o = new Date(), e = o.toLocaleDateString("en-GB").replace(/\//g, "-"), t = o.toLocaleTimeString("en-GB", { hour12: !1 }).replace(/:/g, "-"); return `${e}-${t}`; } -var ge = function (o, e, t, n) { - var r = arguments.length, - s = - r < 3 ? e : n === null ? (n = Object.getOwnPropertyDescriptor(e, t)) : n, - a; - if (typeof Reflect == "object" && typeof Reflect.decorate == "function") - s = Reflect.decorate(o, e, t, n); - else - for (var i = o.length - 1; i >= 0; i--) - (a = o[i]) && (s = (r < 3 ? a(s) : r > 3 ? a(e, t, s) : a(e, t)) || s); - return r > 3 && s && Object.defineProperty(e, t, s), s; -}; -let U = class { +class le { dispose() { throw new Error("Method not implemented."); } @@ -81,34 +77,33 @@ let U = class { r = { cells: [] }; } const s = r.cells.map((i) => { - var u; - const c = new l.NotebookCellData( + var c; + const l = new u.NotebookCellData( i.cell_type, - (u = i.source) == null + (c = i.source) == null ? void 0 - : u.join(` + : c.join(` `), i.languageId, ); - return (c.metadata = i.metadata), (c.outputs = i.outputs), c; + return (l.metadata = i.metadata), (l.outputs = i.outputs), l; }), - a = new l.NotebookData(s); + a = new u.NotebookData(s); return (a.metadata = r.metadata), a; } async serializeNotebook(e, t) { - const n = X(e); + const n = Z(e); return new TextEncoder().encode(JSON.stringify(n)); } -}; -U = ge([p.provideSingleton(U)], U); -var P; +} +var I; (function (o) { (o.error = "application/vnd.code.notebook.error"), (o.stderr = "application/vnd.code.notebook.stderr"), (o.stdout = "application/vnd.code.notebook.stdout"); -})(P || (P = {})); -const ye = ["text/plain", "text/markdown", P.stderr, P.stdout], - ne = [ +})(I || (I = {})); +const ye = ["text/plain", "text/markdown", I.stderr, I.stdout], + oe = [ "application/vnd.*", "application/vdom.*", "application/geo+json", @@ -124,14 +119,14 @@ const ye = ["text/plain", "text/markdown", P.stderr, P.stdout], "application/json", "text/plain", ], - D = new Map(); -D.set("display_data", G); -D.set("error", Ce); -D.set("execute_result", G); -D.set("stream", Se); -D.set("update_display_data", G); -function Y(o) { - const e = D.get(o.output_type); + P = new Map(); +P.set("display_data", Y); +P.set("error", Se); +P.set("execute_result", Y); +P.set("stream", Ce); +P.set("update_display_data", Y); +function J(o) { + const e = P.get(o.output_type); let t; return ( e @@ -139,11 +134,11 @@ function Y(o) { : (console.warn( `Unable to translate cell from ${o.output_type} to NotebookCellData for VS Code.`, ), - (t = G(o))), + (t = Y(o))), t ); } -function ee(o) { +function te(o) { const e = { outputType: o.output_type }; switch ((o.transient && (e.transient = o.transient), o.output_type)) { case "display_data": @@ -156,71 +151,73 @@ function ee(o) { } return e; } -function G(o) { - const e = ee(o); +function Y(o) { + const e = te(o); ("image/svg+xml" in o.data || "image/png" in o.data) && (e.__displayOpenPlotIcon = !0); const t = []; if (o.data) for (const n in o.data) t.push(ke(n, o.data[n])); - return new l.NotebookCellOutput(we(t), e); + return new u.NotebookCellOutput(we(t), e); } function we(o) { return o.sort((e, t) => { const n = (a, i) => ( a.endsWith(".*") && (a = a.substr(0, a.indexOf(".*"))), i.startsWith(a) ); - let r = ne.findIndex((a) => n(a, e.mime)), - s = ne.findIndex((a) => n(a, t.mime)); + let r 
= oe.findIndex((a) => n(a, e.mime)), + s = oe.findIndex((a) => n(a, t.mime)); return ( - oe(e) && (r = -1), - oe(t) && (s = -1), + re(e) && (r = -1), + re(t) && (s = -1), (r = r === -1 ? 100 : r), (s = s === -1 ? 100 : s), r - s ); }); } -function oe(o) { +function re(o) { if (o.mime.startsWith("application/vnd.")) try { return new TextDecoder().decode(o.data).length === 0; - } catch {} + } catch { + console.error(`Failed to decode ${o.mime}`, o.data); + } return !1; } function ke(o, e) { - if (!e) return l.NotebookCellOutputItem.text("", o); + if (!e) return u.NotebookCellOutputItem.text("", o); try { if ( (o.startsWith("text/") || ye.includes(o)) && (Array.isArray(e) || typeof e == "string") ) { - const t = Array.isArray(e) ? z(e) : e; - return l.NotebookCellOutputItem.text(t, o); + const t = Array.isArray(e) ? G(e) : e; + return u.NotebookCellOutputItem.text(t, o); } else return o.startsWith("image/") && typeof e == "string" && o !== "image/svg+xml" - ? new l.NotebookCellOutputItem(_e(e), o) + ? new u.NotebookCellOutputItem(Te(e), o) : typeof e == "object" && e !== null && !Array.isArray(e) - ? l.NotebookCellOutputItem.text(JSON.stringify(e), o) - : ((e = Array.isArray(e) ? z(e) : e), - l.NotebookCellOutputItem.text(e, o)); + ? u.NotebookCellOutputItem.text(JSON.stringify(e), o) + : ((e = Array.isArray(e) ? G(e) : e), + u.NotebookCellOutputItem.text(e, o)); } catch (t) { return ( console.error( `Failed to convert ${o} output to a buffer ${typeof e}, ${e}`, t, ), - l.NotebookCellOutputItem.text("") + u.NotebookCellOutputItem.text("") ); } } -function _e(o) { +function Te(o) { return typeof Buffer < "u" && typeof Buffer.from == "function" ? Buffer.from(o, "base64") : Uint8Array.from(atob(o), (e) => e.charCodeAt(0)); } -function z(o) { +function G(o) { if (Array.isArray(o)) { let e = ""; for (let t = 0; t < o.length; t += 1) { @@ -236,13 +233,13 @@ function z(o) { } return o.toString(); } -function ve(o) { +function _e(o) { let e = o; do (o = e), (e = o.replace(/[^\n]\x08/gm, "")); while (e.length < o.length); return o; } -function Te(o) { +function Ee(o) { for ( o = o.replace( /\r+\n/gm, @@ -259,15 +256,15 @@ function Te(o) { } return o; } -function Ee(o) { - return Te(ve(o)); +function ve(o) { + return Ee(_e(o)); } -function B(o) { +function K(o) { if (o.parent_header && "msg_id" in o.parent_header) return o.parent_header.msg_id; } function Ne(o) { - if (o.hasOwnProperty("text/html")) { + if (Object.prototype.hasOwnProperty.call(o, "text/html")) { const e = o["text/html"]; typeof e == "string" && e.includes('