diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e53648d7c..1c498f610 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,10 @@ repos: -- repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.14.3 - hooks: - # Run the linter. - - id: ruff - args: [ --fix ] - # Run the formatter. - - id: ruff-format \ No newline at end of file + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.14.3 + hooks: + # Run the linter. + - id: ruff-check + args: [--fix] + # Run the formatter. + - id: ruff-format diff --git a/CHANGELOG.md b/CHANGELOG.md index c265715c7..f533156b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add SQL importer support for Teradata databases via `teradata` server type (#986) + ### Fixed - Fix `datacontract init` to generate ODCS format instead of deprecated Data Contract Specification (#984) @@ -20,8 +22,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.11.2] - 2025-12-15 ### Added + - Add Impala engine support for Soda scans via ODCS `impala` server type. + ### Fixed + - Restrict DuckDB dependency to < 1.4.0 (#972) ## [0.11.1] - 2025-12-14 @@ -39,7 +44,7 @@ Not all features that were available are supported in this version, as some feat - Support for `scale` and `precision` (define them in `physicalType`) The reason for this change is that the Data Contract Specification is deprecated, we focus on best possible support for the Open Data Contract Standard. -We try to make this transition as seamless as possible. +We try to make this transition as seamless as possible. If you face issues, please open an issue on GitHub. We continue support reading [Data Contract Specification](https://datacontract-specification.com) data contracts during v0.11.x releases until end of 2026. @@ -66,7 +71,6 @@ To migrate existing data contracts to Open Data Contract Standard use this instr - The `breaking`, `changelog`, and `diff` commands are now deleted (#925). - The `terraform` export format has been removed. - ## [0.10.41] - 2025-12-02 ### Changed @@ -139,7 +143,6 @@ To migrate existing data contracts to Open Data Contract Standard use this instr - export: Excel exporter now exports critical data element - ## [0.10.36] - 2025-10-17 ### Added @@ -157,8 +160,8 @@ To migrate existing data contracts to Open Data Contract Standard use this instr - Export to DQX: Correct DQX format for global-level quality check of data contract export. (#877) - Import the table tags from a open data contract spec v3 (#895) - dbt export: Enhanced model-level primaryKey support with automatic test generation for single and multiple column primary keys (#898) -- ODCS: field discarded when no logicalType defined (#891) - +- ODCS: field discarded when no logicalType defined (#891) + ### Removed - Removed specific linters, as the linters did not support ODCS (#913) @@ -188,60 +191,68 @@ To migrate existing data contracts to Open Data Contract Standard use this instr - Avro Importer: Optional and required enum types are now supported (#804) - ## [0.10.33] - 2025-07-29 ### Added + - Export to Excel: Convert ODCS YAML to Excel https://github.com/datacontract/open-data-contract-standard-excel-template (#742) - Extra properties in Markdown export. 
(#842) - ## [0.10.32] - 2025-07-28 ### Added + - Import from Excel: Support the new quality sheet ### Fixed + - JUnit Test Report: Fixed incorrect syntax on handling warning test report. (#833) ## [0.10.31] - 2025-07-18 ### Added -- Added support for Variant with Spark exporter, data_contract.test(), and import as source unity catalog (#792) +- Added support for Variant with Spark exporter, data_contract.test(), and import as source unity catalog (#792) ## [0.10.30] - 2025-07-15 ### Fixed + - Excel Import should return ODCS YAML (#829) - Excel Import: Missing server section when the server included a schema property (#823) ### Changed + - Use ` ` instead of ` ` for tab in Markdown export. ## [0.10.29] - 2025-07-06 ### Added + - Support for Data Contract Specification v1.2.0 - `datacontract import --format json`: Import from JSON files ### Changed + - `datacontract api [OPTIONS]`: Added option to pass extra arguments for `uvicorn.run()` ### Fixed + - `pytest tests\test_api.py`: Fixed an issue where special characters were not read correctly from file. - `datacontract export --format mermaid`: Fixed an issue where the `mermaid` export did not handle references correctly ## [0.10.28] - 2025-06-05 ### Added + - Much better ODCS support - - Import anything to ODCS via the `import --spec odcs` flag - - Export to HTML with an ODCS native template via `export --format html` - - Export to Mermaid with an ODCS native mapping via `export --format mermaid` + - Import anything to ODCS via the `import --spec odcs` flag + - Export to HTML with an ODCS native template via `export --format html` + - Export to Mermaid with an ODCS native mapping via `export --format mermaid` - The databricks `unity` importer now supports more than a single table. You can use `--unity-table-full-name` multiple times to import multiple tables. And it will automatically add a server with the catalog and schema name. ### Changed + - `datacontract catalog [OPTIONS]`: Added version to contract cards in `index.html` of the catalog (enabled search by version) - The type mapping of the `unity` importer no uses the native databricks types instead of relying on spark types. This allows for better type mapping and more accurate data contracts. @@ -267,10 +278,10 @@ To migrate existing data contracts to Open Data Contract Standard use this instr - `datacontract export --format sodacl`: Fix resolving server when using `--server` flag (#768) - `datacontract export --format dbt`: Fixed DBT export behaviour of constraints to default to data tests when no model type is specified in the datacontract model - ## [0.10.26] - 2025-05-16 ### Changed + - Databricks: Add support for Variant type (#758) - `datacontract export --format odcs`: Export physical type if the physical type is configured in config object (#757) @@ -280,6 +291,7 @@ To migrate existing data contracts to Open Data Contract Standard use this instr ## [0.10.25] - 2025-05-07 ### Added + - Extracted the DataContractSpecification and the OpenDataContractSpecification in separate pip modules and use them in the CLI. - `datacontract import --format excel`: Import from Excel template https://github.com/datacontract/open-data-contract-standard-excel-template (#742) @@ -315,8 +327,7 @@ To migrate existing data contracts to Open Data Contract Standard use this instr to a file, in a standard format (e.g. 
JUnit) to improve CI/CD experience (#650) - Added import for `ProtoBuf` -Code for proto to datacontract (#696) - + Code for proto to datacontract (#696) - `dbt` & `dbt-sources` export formats now support the optional `--server` flag to adapt the DBT column `data_type` to specific SQL dialects - Duckdb Connections are now configurable, when used as Python library (#666) @@ -402,6 +413,7 @@ Code for proto to datacontract (#696) - Option to separate physical table name for a model via config option (#270) ### Changed + - JSON Schemas are now bundled with the application (#598) - datacontract export --format html: The model title is now shown if it is different to the model name (#585) @@ -411,8 +423,8 @@ Code for proto to datacontract (#696) - datacontract export --format sql: Create arrays and struct for Databricks (#467) ### Fixed -- datacontract lint: Linter 'Field references existing field' too many values to unpack (expected - 2) (#586) + +- datacontract lint: Linter 'Field references existing field' too many values to unpack (expected 2) (#586) - datacontract test (Azure): Error querying delta tables from azure storage. (#458) - datacontract export --format data-caterer: Use `fields` instead of `schema` - datacontract export --format data-caterer: Use `options` instead of `generator.options` @@ -422,6 +434,7 @@ Code for proto to datacontract (#696) ## [0.10.18] - 2025-01-18 ### Fixed + - Fixed an issue when resolving project's dependencies when all extras are installed. - Definitions referenced by nested fields are not validated correctly (#595) - Replaced deprecated `primary` field with `primaryKey` in exporters, importers, examples, and Jinja templates for backward compatibility. Fixes [#518](https://github.com/your-repo/your-project/issues/518). @@ -430,6 +443,7 @@ Code for proto to datacontract (#696) ## [0.10.17] - 2025-01-16 ### Added + - added export format **markdown**: `datacontract export --format markdown` (#545) - When importing in dbt format, add the dbt unique information as a datacontract unique field (#558) - When importing in dbt format, add the dbt primary key information as a datacontract primaryKey field (#562) @@ -438,28 +452,34 @@ Code for proto to datacontract (#696) - Add serve command on README (#592) ### Changed + - Primary and example fields have been deprecated in Data Contract Specification v1.1.0 (#561) - Define primaryKey and examples for model to follow the changes in datacontract-specification v1.1.0 (#559) ### Fixed + - SQL Server: cannot escape reserved word on model (#557) - Export dbt-staging-sql error on multi models contracts (#587) ### Removed + - OpenTelemetry publisher, as it was hardly used ## [0.10.16] - 2024-12-19 ### Added + - Support for exporting a Data Contract to an Iceberg schema definition. 
- When importing in dbt format, add the dbt `not_null` information as a datacontract `required` field (#547) ### Changed + - Type conversion when importing contracts into dbt and exporting contracts from dbt (#534) - Ensure 'name' is the first column when exporting in dbt format, considering column attributes (#541) - Rename dbt's `tests` to `data_tests` (#548) ### Fixed + - Modify the arguments to narrow down the import target with `--dbt-model` (#532) - SodaCL: Prevent `KeyError: 'fail'` from happening when testing with SodaCL - fix: populate database and schema values for bigquery in exported dbt sources (#543) @@ -469,6 +489,7 @@ Code for proto to datacontract (#696) ## [0.10.15] - 2024-12-02 ### Added + - Support for model import from parquet file metadata. - Great Expectation export: add optional args (#496) - `suite_name` the name of the expectation suite to export @@ -480,11 +501,13 @@ Code for proto to datacontract (#696) - fixes issue where records with no fields create an invalid bq schema. ### Changed + - Changelog support for custom extension keys in `Models` and `Fields` blocks. - `datacontract catalog --files '*.yaml'` now checks also any subfolders for such files. - Optimize test output table on console if tests fail ### Fixed + - raise valid exception in DataContractSpecification.from_file if file does not exist - Fix importing JSON Schemas containing deeply nested objects without `required` array - SodaCL: Only add data quality tests for executable queries @@ -494,6 +517,7 @@ Code for proto to datacontract (#696) Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. ### Added + - `datacontract test` now also supports ODCS v3 data contract format - `datacontract export --format odcs_v3`: Export to Open Data Contract Standard v3.0.0 (#460) - `datacontract test` now also supports ODCS v3 anda Data Contract SQL quality checks on field and model level @@ -502,25 +526,29 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. - Support for custom Trino types ### Changed + - `datacontract import --format odcs`: Now supports ODSC v3.0.0 files (#474) - `datacontract export --format odcs`: Now creates v3.0.0 Open Data Contract Standard files (alias to odcs_v3). Old versions are still available as format `odcs_v2`. (#460) ### Fixed -- fix timestamp serialization from parquet -> duckdb (#472) +- fix timestamp serialization from parquet -> duckdb (#472) ## [0.10.13] - 2024-09-20 ### Added + - `datacontract export --format data-caterer`: Export to [Data Caterer YAML](https://data.catering/setup/guide/scenario/data-generation/) ### Changed + - `datacontract export --format jsonschema` handle optional and nullable fields (#409) - `datacontract import --format unity` handle nested and complex fields (#420) - `datacontract import --format spark` handle field descriptions (#420) - `datacontract export --format bigquery` handle bigqueryType (#422) ### Fixed + - use correct float type with bigquery (#417) - Support DATACONTRACT_MANAGER_API_KEY - Some minor bug fixes @@ -528,6 +556,7 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. ## [0.10.12] - 2024-09-08 ### Added + - Support for import of DBML Models (#379) - `datacontract export --format sqlalchemy`: Export to [SQLAlchemy ORM models](https://docs.sqlalchemy.org/en/20/orm/quickstart.html) (#399) - Support of varchar max length in Glue import (#351) @@ -538,10 +567,12 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. 
- Support of enum title on avro import ### Changed + - Deltalake is now using DuckDB's native deltalake support (#258). Extra deltalake removed. - When dumping to YAML (import) the alias name is used instead of the pythonic name. (#373) ### Fixed + - Fix an issue where the datacontract cli fails if installed without any extras (#400) - Fix an issue where Glue database without a location creates invalid data contract (#351) - Fix bigint -> long data type mapping (#351) @@ -551,7 +582,6 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. - Fix a model bug mismatching with the specification (`definitions.fields`) (#375) - Fix array type management in Spark import (#408) - ## [0.10.11] - 2024-08-08 ### Added @@ -573,22 +603,24 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. - Fix an issue where the JSON schema validation fails for a field with type `string` and format `uuid` - Fix an issue where common DBML renderers may not be able to parse parts of an exported file - ## [0.10.10] - 2024-07-18 ### Added + - Add support for dbt manifest file (#104) - Fix import of pyspark for type-checking when pyspark isn't required as a module (#312) - Adds support for referencing fields within a definition (#322) - Add `map` and `enum` type for Avro schema import (#311) ### Fixed + - Fix import of pyspark for type-checking when pyspark isn't required as a module (#312)- `datacontract import --format spark`: Import from Spark tables (#326) - Fix an issue where specifying `glue_table` as parameter did not filter the tables and instead returned all tables from `source` database (#333) ## [0.10.9] - 2024-07-03 ### Added + - Add support for Trino (#278) - Spark export: add Spark StructType exporter (#277) - add `--schema` option for the `catalog` and `export` command to provide the schema also locally @@ -597,20 +629,24 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. - Add support for AWS SESSION_TOKEN (#309) ### Changed + - Added array management on HTML export (#299) ### Fixed + - Fix `datacontract import --format jsonschema` when description is missing (#300) - Fix `datacontract test` with case-sensitive Postgres table names (#310) ## [0.10.8] - 2024-06-19 ### Added + - `datacontract serve` start a local web server to provide a REST-API for the commands - Provide server for sql export for the appropriate schema (#153) - Add struct and array management to Glue export (#271) ### Changed + - Introduced optional dependencies/extras for significantly faster installation times. (#213) - Added delta-lake as an additional optional dependency - support `GOOGLE_APPLICATION_CREDENTIALS` as variable for connecting to bigquery in `datacontract test` @@ -619,14 +655,17 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. - added the importer factory. This refactoring aims to make it easier to create new importers and consequently the growth and maintainability of the project. (#273) ### Fixed + - `datacontract export --format avro` fixed array structure (#243) ## [0.10.7] - 2024-05-31 ### Added + - Test data contract against dataframes / temporary views (#175) ### Fixed + - AVRO export: Logical Types should be nested (#233) ## [0.10.6] - 2024-05-29 @@ -638,6 +677,7 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. 
## [0.10.5] - 2024-05-29 ### Added + - Added support for `sqlserver` (#196) - `datacontract export --format dbml`: Export to [Database Markup Language (DBML)](https://dbml.dbdiagram.io/home/) (#135) - `datacontract export --format avro`: Now supports config map on field level for logicalTypes and default values [Custom Avro Properties](./README.md#custom-avro-properties) @@ -668,6 +708,7 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. ## [0.10.3] - 2024-05-05 ### Fixed + - datacontract catalog: Add index.html to manifest ## [0.10.2] - 2024-05-05 @@ -681,8 +722,8 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. - Added field format information to HTML export ### Fixed -- RDF Export: Fix error if owner is not a URI/URN +- RDF Export: Fix error if owner is not a URI/URN ## [0.10.1] - 2024-04-19 @@ -741,7 +782,6 @@ Data Contract CLI now supports the Open Data Contract Standard (ODCS) v3.0.0. - Fixed a bug where the export to YAML always escaped the unicode characters. - ## [0.9.6-2] - 2024-03-04 ### Added @@ -755,6 +795,7 @@ This is a huge step forward, we now support testing Kafka messages. We start with JSON messages and avro, and Protobuf will follow. ### Added + - test kafka for JSON messages - added import format **sql**: `datacontract import --format sql` (#51) - added export format **dbt-sources**: `datacontract export --format dbt-sources` @@ -765,6 +806,7 @@ We start with JSON messages and avro, and Protobuf will follow. ## [0.9.5] - 2024-02-22 ### Added + - export to dbt models (#37). - export to ODCS (#49). - test - show a test summary table. @@ -773,12 +815,14 @@ We start with JSON messages and avro, and Protobuf will follow. ## [0.9.4] - 2024-02-18 ### Added + - Support for Postgres - Support for Databricks ## [0.9.3] - 2024-02-10 ### Added + - Support for BigQuery data connection - Support for multiple models with S3 @@ -789,6 +833,7 @@ We start with JSON messages and avro, and Protobuf will follow. ## [0.9.2] - 2024-01-31 ### Added + - Publish to Docker Hub ## [0.9.0] - 2024-01-26 - BREAKING @@ -798,15 +843,18 @@ The project migrated from Golang to Python. The Golang version can be found at [cli-go](https://github.com/datacontract/cli-go) ### Added + - `test` Support to directly run tests and connect to data sources defined in servers section. - `test` generated schema tests from the model definition. - `test --publish URL` Publish test results to a server URL. - `export` now exports the data contract so format jsonschema and sodacl. ### Changed + - The `--file` option removed in favor of a direct argument.: Use `datacontract test datacontract.yaml` instead of `datacontract test --file datacontract.yaml`. ### Removed + - `model` is now part of `export` - `quality` is now part of `export` - Temporary Removed: `diff` needs to be migrated to Python. @@ -814,29 +862,40 @@ The Golang version can be found at [cli-go](https://github.com/datacontract/cli- - Temporary Removed: `inline` needs to be migrated to Python. ## [0.6.0] + ### Added + - Support local json schema in lint command. - Update to specification 0.9.2. ## [0.5.3] + ### Fixed + - Fix format flag bug in model (print) command. ## [0.5.2] + ### Changed + - Log to STDOUT. - Rename `model` command parameter, `type` -> `format`. ## [0.5.1] + ### Removed + - Remove `schema` command. ### Fixed + - Fix documentation. - Security update of x/sys. ## [0.5.0] + ### Added + - Adapt Data Contract Specification in version 0.9.2. - Use `models` section for `diff`/`breaking`. 
- Add `model` command. @@ -844,22 +903,31 @@ The Golang version can be found at [cli-go](https://github.com/datacontract/cli- - Let `quality` write input from STDIN if present. ## [0.4.0] + ### Added + - Basic implementation of `test` command for Soda Core. ### Changed + - Change package structure to allow usage as library. ## [0.3.2] + ### Fixed + - Fix field parsing for dbt models, affects stability of `diff`/`breaking`. ## [0.3.1] + ### Fixed + - Fix comparing order of contracts in `diff`/`breaking`. ## [0.3.0] + ### Added + - Handle non-existent schema specification when using `diff`/`breaking`. - Resolve local and remote resources such as schema specifications when using "$ref: ..." notation. - Implement `schema` command: prints your schema. @@ -867,23 +935,31 @@ The Golang version can be found at [cli-go](https://github.com/datacontract/cli- - Implement the `inline` command: resolves all references using the "$ref: ..." notation and writes them to your data contract. ### Changed + - Allow remote and local location for all data contract inputs (`--file`, `--with`). ## [0.2.0] + ### Added + - Add `diff` command for dbt schema specification. - Add `breaking` command for dbt schema specification. ### Changed + - Suggest a fix during `init` when the file already exists. - Rename `validate` command to `lint`. ### Removed + - Remove `check-compatibility` command. ### Fixed + - Improve usage documentation. ## [0.1.1] + ### Added + - Initial release. diff --git a/datacontract/imports/odcs_helper.py b/datacontract/imports/odcs_helper.py index 5c01daf0a..c0ed4c568 100644 --- a/datacontract/imports/odcs_helper.py +++ b/datacontract/imports/odcs_helper.py @@ -1,6 +1,8 @@ """Helper functions for creating ODCS (OpenDataContractStandard) objects.""" -from typing import Any, Dict, List +from __future__ import annotations + +from typing import Any from open_data_contract_standard.model import ( CustomProperty, @@ -12,8 +14,8 @@ def create_odcs( - id: str = None, - name: str = None, + id: str | None = None, + name: str | None = None, version: str = "1.0.0", status: str = "draft", ) -> OpenDataContractStandard: @@ -31,9 +33,9 @@ def create_odcs( def create_schema_object( name: str, physical_type: str = "table", - description: str = None, - business_name: str = None, - properties: List[SchemaProperty] = None, + description: str | None = None, + business_name: str | None = None, + properties: list[SchemaProperty] | None = None, ) -> SchemaObject: """Create a SchemaObject (equivalent to DCS Model).""" schema = SchemaObject( @@ -54,28 +56,28 @@ def create_schema_object( def create_property( name: str, logical_type: str, - physical_type: str = None, - description: str = None, - required: bool = None, - primary_key: bool = None, - primary_key_position: int = None, - unique: bool = None, - classification: str = None, - tags: List[str] = None, - examples: List[Any] = None, - min_length: int = None, - max_length: int = None, - pattern: str = None, - minimum: float = None, - maximum: float = None, - exclusive_minimum: float = None, - exclusive_maximum: float = None, - precision: int = None, - scale: int = None, - format: str = None, - properties: List["SchemaProperty"] = None, - items: "SchemaProperty" = None, - custom_properties: Dict[str, Any] = None, + physical_type: str | None = None, + description: str | None = None, + required: bool = False, + primary_key: bool = False, + primary_key_position: int | None = None, + unique: bool = False, + classification: str | None = None, + tags: list[str] | None = None, + 
examples: list[Any] | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, + minimum: float | None = None, + maximum: float | None = None, + exclusive_minimum: float | None = None, + exclusive_maximum: float | None = None, + precision: int | None = None, + scale: int | None = None, + format: str | None = None, + properties: list[SchemaProperty] | None = None, + items: SchemaProperty | None = None, + custom_properties: dict[str, Any] | None = None, ) -> SchemaProperty: """Create a SchemaProperty (equivalent to DCS Field).""" prop = SchemaProperty(name=name) @@ -85,7 +87,7 @@ def create_property( prop.physicalType = physical_type if description: prop.description = description - if required is not None: + if required: prop.required = required if primary_key: prop.primaryKey = primary_key @@ -130,9 +132,7 @@ def create_property( # Custom properties if custom_properties: - prop.customProperties = [ - CustomProperty(property=k, value=v) for k, v in custom_properties.items() - ] + prop.customProperties = [CustomProperty(property=k, value=v) for k, v in custom_properties.items()] return prop @@ -140,19 +140,19 @@ def create_property( def create_server( name: str, server_type: str, - environment: str = None, - host: str = None, - port: int = None, - database: str = None, - schema: str = None, - account: str = None, - project: str = None, - dataset: str = None, - path: str = None, - location: str = None, - catalog: str = None, - topic: str = None, - format: str = None, + environment: str | None = None, + host: str | None = None, + port: int | None = None, + database: str | None = None, + schema: str | None = None, + account: str | None = None, + project: str | None = None, + dataset: str | None = None, + path: str | None = None, + location: str | None = None, + catalog: str | None = None, + topic: str | None = None, + format: str | None = None, ) -> Server: """Create a Server object.""" server = Server(server=name, type=server_type) diff --git a/datacontract/imports/sql_importer.py b/datacontract/imports/sql_importer.py index 5d3f4ac0c..40c9c0d8f 100644 --- a/datacontract/imports/sql_importer.py +++ b/datacontract/imports/sql_importer.py @@ -1,9 +1,17 @@ +"""SQL importer for data contracts. + +This module provides functionality to import SQL DDL statements and convert them +into OpenDataContractStandard data contract specifications. +""" + import logging -import os +import pathlib +from typing import Any import sqlglot from open_data_contract_standard.model import OpenDataContractStandard from sqlglot.dialects.dialect import Dialects +from sqlglot.expressions import ColumnDef, Table from datacontract.imports.importer import Importer from datacontract.imports.odcs_helper import ( @@ -15,70 +23,87 @@ from datacontract.model.exceptions import DataContractException from datacontract.model.run import ResultEnum +logger = logging.getLogger(__name__) + class SqlImporter(Importer): - def import_source( - self, source: str, import_args: dict - ) -> OpenDataContractStandard: + """Importer for SQL DDL files.""" + + def import_source(self, source: str, import_args: dict[str, str]) -> OpenDataContractStandard: + """Import source into the data contract specification. + + Args: + source: The source file path. + import_args: Additional import arguments. + + Returns: + The populated data contract specification. 
+ """ return import_sql(self.import_format, source, import_args) -def import_sql( - format: str, source: str, import_args: dict = None -) -> OpenDataContractStandard: +def import_sql(import_format: str, source: str, import_args: dict[str, str] | None = None) -> OpenDataContractStandard: + """Import SQL into the data contract specification. + + Args: + import_format: The type of import format (e.g. "sql"). + source: The source file path. + import_args: Additional import arguments. + + Returns: + The populated data contract specification. + """ sql = read_file(source) - dialect = to_dialect(import_args) + + dialect = None + if import_args is not None and "dialect" in import_args: + dialect = to_dialect(import_args["dialect"]) try: parsed = sqlglot.parse_one(sql=sql, read=dialect) except Exception as e: - logging.error(f"Error parsing SQL: {str(e)}") + logger.exception("Error parsing SQL") raise DataContractException( type="import", name=f"Reading source from {source}", - reason=f"Error parsing SQL: {str(e)}", + reason=f"Error parsing SQL: {e!s}", engine="datacontract", result=ResultEnum.error, - ) + ) from e odcs = create_odcs() odcs.schema_ = [] - server_type = to_server_type(source, dialect) + server_type = to_server_type(dialect) if dialect is not None else None if server_type is not None: odcs.servers = [create_server(name=server_type, server_type=server_type)] - tables = parsed.find_all(sqlglot.expressions.Table) + tables = parsed.find_all(Table) for table in tables: table_name = table.this.name - properties = [] + properties: list[Any] = [] primary_key_position = 1 - for column in parsed.find_all(sqlglot.exp.ColumnDef): - if column.parent.this.name != table_name: + for column in parsed.find_all(ColumnDef): + if column.parent is None or column.parent.this.name != table_name: continue - col_name = column.this.name col_type = to_col_type(column, dialect) - logical_type = map_type_from_sql(col_type) - col_description = get_description(column) - max_length = get_max_length(column) - precision, scale = get_precision_scale(column) is_primary_key = get_primary_key(column) - is_required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None + precision, scale = get_precision_scale(column) prop = create_property( - name=col_name, - logical_type=logical_type, + name=column.this.name, + logical_type=(map_type_from_sql(col_type, dialect) if col_type is not None else "object"), physical_type=col_type, - description=col_description, - max_length=max_length, + description=get_description(column), + max_length=get_max_length(column), precision=precision, scale=scale, primary_key=is_primary_key, primary_key_position=primary_key_position if is_primary_key else None, - required=is_required if is_required else None, + required=column.find(sqlglot.exp.NotNullColumnConstraint) is not None, ) if is_primary_key: @@ -96,51 +121,51 @@ def import_sql( return odcs -def get_primary_key(column) -> bool | None: - if column.find(sqlglot.exp.PrimaryKeyColumnConstraint) is not None: - return True - if column.find(sqlglot.exp.PrimaryKey) is not None: - return True - return None +def get_primary_key(column: ColumnDef) -> bool: + """Determine if the column is a primary key. + Args: + column: The SQLGlot column expression. -def to_dialect(import_args: dict) -> Dialects | None: - if import_args is None: - return None - if "dialect" not in import_args: - return None - dialect = import_args.get("dialect") - if dialect is None: + Returns: + True if primary key, False if not or undetermined. 
+ """ + return ( + column.find(sqlglot.exp.PrimaryKeyColumnConstraint) is not None + or column.find(sqlglot.exp.PrimaryKey) is not None + ) + + +def to_dialect(args_dialect: str | None) -> Dialects | None: + """Convert import arguments to SQLGlot dialect. + + Args: + args_dialect: The dialect string from import arguments. + + Returns: + The corresponding SQLGlot dialect or None if not found. + """ + if args_dialect is None: return None - if dialect == "sqlserver": + if args_dialect.lower() == "sqlserver": return Dialects.TSQL - if dialect.upper() in Dialects.__members__: - return Dialects[dialect.upper()] - if dialect == "sqlserver": - return Dialects.TSQL - return None - - -def to_physical_type_key(dialect: Dialects | str | None) -> str: - dialect_map = { - Dialects.TSQL: "sqlserverType", - Dialects.POSTGRES: "postgresType", - Dialects.BIGQUERY: "bigqueryType", - Dialects.SNOWFLAKE: "snowflakeType", - Dialects.REDSHIFT: "redshiftType", - Dialects.ORACLE: "oracleType", - Dialects.MYSQL: "mysqlType", - Dialects.DATABRICKS: "databricksType", - } - if isinstance(dialect, str): - dialect = Dialects[dialect.upper()] if dialect.upper() in Dialects.__members__ else None - return dialect_map.get(dialect, "physicalType") + elif args_dialect.upper() in Dialects.__members__: + return Dialects[args_dialect.upper()] + else: + logger.warning("Dialect '%s' not recognized, defaulting to None", args_dialect) + return None -def to_server_type(source, dialect: Dialects | None) -> str | None: - if dialect is None: - return None - dialect_map = { +def to_server_type(dialect: Dialects) -> str | None: + """Convert dialect to ODCS object server type. + + Args: + dialect: The SQLGlot dialect. + + Returns: + The corresponding server type or None if not found. + """ + server_type = { Dialects.TSQL: "sqlserver", Dialects.POSTGRES: "postgres", Dialects.BIGQUERY: "bigquery", @@ -149,150 +174,175 @@ def to_server_type(source, dialect: Dialects | None) -> str | None: Dialects.ORACLE: "oracle", Dialects.MYSQL: "mysql", Dialects.DATABRICKS: "databricks", - } - return dialect_map.get(dialect, None) + Dialects.TERADATA: "teradata", + }.get(dialect) + + if server_type is None: + logger.warning("No server type mapping for dialect '%s', defaulting to None", dialect) + return server_type -def to_col_type(column, dialect): +def to_col_type(column: ColumnDef, dialect: Dialects | None) -> str | None: + """Convert column to SQL type string. + + Args: + column: The SQLGlot column expression. + dialect: The SQLGlot dialect. + + Returns: + The SQL type string or None if not found. + """ col_type_kind = column.args["kind"] - if col_type_kind is None: - return None + return col_type_kind.sql(dialect) if col_type_kind is not None else None - return col_type_kind.sql(dialect) +def to_col_type_normalized(column: ColumnDef) -> str | None: + """Convert column to normalized SQL type string. -def to_col_type_normalized(column): - col_type = column.args["kind"].this.name - if col_type is None: + Args: + column: The SQLGlot column expression. + + Returns: + The normalized SQL type string or None if not found. 
+ """ + if column.args["kind"] is None: return None - return col_type.lower() + col_type = column.args["kind"].this.name + return col_type.lower() if col_type is not None else None -def get_description(column: sqlglot.expressions.ColumnDef) -> str | None: - if column.comments is None: - return None - return " ".join(comment.strip() for comment in column.comments) +def get_description(column: ColumnDef) -> str | None: + """Get the description from column comments. + + Args: + column: The SQLGlot column expression. + + Returns: + The description string or None if not found. + """ + return " ".join(comment.strip() for comment in column.comments) if column.comments is not None else None -def get_max_length(column: sqlglot.expressions.ColumnDef) -> int | None: +def get_max_length(column: ColumnDef) -> int | None: + """Get the maximum length from column definition. + + Args: + column: The SQLGlot column expression. + + Returns: + The maximum length or None if not found. + """ col_type = to_col_type_normalized(column) - if col_type is None: - return None - if col_type not in ["varchar", "char", "nvarchar", "nchar"]: + if col_type is None or col_type not in ["varchar", "char", "nvarchar", "nchar"]: return None - col_params = list(column.args["kind"].find_all(sqlglot.expressions.DataTypeParam)) + col_params: list[Any] = list(column.args["kind"].find_all(sqlglot.expressions.DataTypeParam)) max_length_str = None - if len(col_params) == 0: - return None - if len(col_params) == 1: - max_length_str = col_params[0].name - if len(col_params) == 2: - max_length_str = col_params[1].name + match len(col_params): + case 0: + return None + case 1: + max_length_str = col_params[0].name + case 2: + max_length_str = col_params[1].name + if max_length_str is not None: return int(max_length_str) if max_length_str.isdigit() else None -def get_precision_scale(column): +def get_precision_scale(column: ColumnDef) -> tuple[int | None, int | None]: + """Get the precision and scale from column definition. + + Args: + column: The SQLGlot column expression. + + Returns: + The precision and scale or None if not found. + """ col_type = to_col_type_normalized(column) - if col_type is None: - return None, None - if col_type not in ["decimal", "numeric", "float", "number"]: + if col_type is None or col_type not in ["decimal", "numeric", "float", "number"]: return None, None + col_params = list(column.args["kind"].find_all(sqlglot.expressions.DataTypeParam)) - if len(col_params) == 0: - return None, None - if len(col_params) == 1: - if not col_params[0].name.isdigit(): + + match col_params: + case []: return None, None - precision = int(col_params[0].name) - scale = 0 - return precision, scale - if len(col_params) == 2: - if not col_params[0].name.isdigit() or not col_params[1].name.isdigit(): + case [first] if first.name.isdigit(): + return int(first.name), 0 + case [first, second] if first.name.isdigit() and second.name.isdigit(): + return int(first.name), int(second.name) + case _: return None, None - precision = int(col_params[0].name) - scale = int(col_params[1].name) - return precision, scale - return None, None -def map_type_from_sql(sql_type: str) -> str | None: - """Map SQL type to ODCS logical type.""" - if sql_type is None: - return None +def map_type_from_sql(sql_type: str, dialect: Dialects | None = None) -> str: + """Map SQL type to ODCS logical type. + + Args: + sql_type: The SQL type string. + dialect: The SQLGlot dialect (optional). + Returns: + The corresponding ODCS logical type. 
+ """ sql_type_normed = sql_type.lower().strip() - if sql_type_normed.startswith("varchar"): - return "string" - elif sql_type_normed.startswith("char"): - return "string" - elif sql_type_normed.startswith("string"): - return "string" - elif sql_type_normed.startswith("nchar"): - return "string" - elif sql_type_normed.startswith("text"): - return "string" - elif sql_type_normed.startswith("nvarchar"): - return "string" - elif sql_type_normed.startswith("ntext"): + # Check exact matches first + exact_matches: dict[str, str] = { + "date": "date", + "time": "string", + "uniqueidentifier": "string", + "json": "object", + "xml": "string", + } + if sql_type_normed in exact_matches: + return exact_matches[sql_type_normed] + + # Check prefix and set matches + string_types = ("varchar", "char", "string", "nchar", "text", "nvarchar", "ntext") + if sql_type_normed.startswith(string_types) or sql_type_normed in ("clob", "nclob"): return "string" - elif sql_type_normed.startswith("int") and not sql_type_normed.startswith("interval"): - return "integer" - elif sql_type_normed.startswith("bigint"): - return "integer" - elif sql_type_normed.startswith("tinyint"): + + # Handle BYTEINT (Teradata single-byte integer) + if sql_type_normed.startswith("byteint"): return "integer" - elif sql_type_normed.startswith("smallint"): + + if (sql_type_normed.startswith("int") and not sql_type_normed.startswith("interval")) or sql_type_normed.startswith( + ("bigint", "tinyint", "smallint") + ): return "integer" - elif sql_type_normed.startswith("float"): - return "number" - elif sql_type_normed.startswith("double"): - return "number" - elif sql_type_normed.startswith("decimal"): - return "number" - elif sql_type_normed.startswith("numeric"): + + if sql_type_normed.startswith(("float", "double", "decimal", "numeric", "number")): return "number" - elif sql_type_normed.startswith("bool"): - return "boolean" - elif sql_type_normed.startswith("bit"): + + if sql_type_normed.startswith(("bool", "bit")): return "boolean" - elif sql_type_normed.startswith("binary"): - return "array" - elif sql_type_normed.startswith("varbinary"): - return "array" - elif sql_type_normed.startswith("raw"): - return "array" - elif sql_type_normed == "blob" or sql_type_normed == "bfile": + + # Handle INTERVAL types - Oracle as object, others as string + if sql_type_normed.startswith("interval"): + return "object" if dialect == Dialects.ORACLE else "string" + + binary_types = ("binary", "varbinary", "raw", "byte", "varbyte") + if sql_type_normed.startswith(binary_types) or sql_type_normed in ("blob", "bfile"): return "array" - elif sql_type_normed == "date": - return "date" - elif sql_type_normed == "time": - return "string" - elif sql_type_normed.startswith("timestamp"): - return "date" - elif sql_type_normed == "datetime" or sql_type_normed == "datetime2": - return "date" - elif sql_type_normed == "smalldatetime": - return "date" - elif sql_type_normed == "datetimeoffset": + + datetime_types = ("datetime", "datetime2", "smalldatetime", "datetimeoffset") + if sql_type_normed.startswith("timestamp") or sql_type_normed in datetime_types: return "date" - elif sql_type_normed == "uniqueidentifier": # tsql - return "string" - elif sql_type_normed == "json": - return "object" - elif sql_type_normed == "xml": # tsql - return "string" - elif sql_type_normed.startswith("number"): - return "number" - elif sql_type_normed == "clob" or sql_type_normed == "nclob": - return "string" - else: - return "object" + + return "object" -def read_file(path): - if not 
os.path.exists(path): +def read_file(path: str) -> str: + """Read the content of a file. + + Args: + path: The file path. + + Returns: + The content of the file. + """ + if not pathlib.Path(path).exists(): raise DataContractException( type="import", name=f"Reading source from {path}", @@ -300,6 +350,14 @@ def read_file(path): engine="datacontract", result=ResultEnum.error, ) - with open(path, "r") as file: + with pathlib.Path(path).open("r") as file: file_content = file.read() + if file_content.strip() == "": + raise DataContractException( + type="import", + name=f"Reading source from {path}", + reason=f"The file '{path}' is empty.", + engine="datacontract", + result=ResultEnum.error, + ) return file_content diff --git a/pyproject.toml b/pyproject.toml index c170cdc86..fe1158f84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ dependencies = [ "requests>=2.31,<2.33", "fastjsonschema>=2.19.1,<2.22.0", "fastparquet>=2024.5.0,<2025.0.0", - "numpy>=1.26.4,<2.0.0", # transitive dependency, needs to be <2.0.0 https://github.com/datacontract/datacontract-cli/issues/575 + "numpy>=1.26.4,<2.0.0", # transitive dependency, needs to be <2.0.0 https://github.com/datacontract/datacontract-cli/issues/575 "python-multipart>=0.0.20,<1.0.0", "rich>=13.7,<15.0", "sqlglot>=26.6.0,<29.0.0", @@ -39,21 +39,13 @@ dependencies = [ [project.optional-dependencies] -avro = [ - "avro==1.12.1" -] +avro = ["avro==1.12.1"] -bigquery = [ - "soda-core-bigquery>=3.3.20,<3.6.0", -] +bigquery = ["soda-core-bigquery>=3.3.20,<3.6.0"] -csv = [ - "pandas >= 2.0.0", -] +csv = ["pandas >= 2.0.0"] -excel = [ - "openpyxl>=3.1.5,<4.0.0", -] +excel = ["openpyxl>=3.1.5,<4.0.0"] databricks = [ @@ -64,9 +56,7 @@ databricks = [ "pyspark>=3.5.5,<5.0.0", ] -iceberg = [ - "pyiceberg==0.10.0" -] +iceberg = ["pyiceberg==0.10.0"] kafka = [ "datacontract-cli[avro]", @@ -74,63 +64,37 @@ kafka = [ "pyspark>=3.5.5,<5.0.0", ] -postgres = [ - "soda-core-postgres>=3.3.20,<3.6.0" -] +postgres = ["soda-core-postgres>=3.3.20,<3.6.0"] -s3 = [ - "s3fs>=2025.2.0,<2026.0.0", - "aiobotocore>=2.17.0,<2.26.0", -] +s3 = ["s3fs>=2025.2.0,<2026.0.0", "aiobotocore>=2.17.0,<2.26.0"] snowflake = [ "snowflake-connector-python[pandas]>=3.6,<4.2", - "soda-core-snowflake>=3.3.20,<3.6.0" + "soda-core-snowflake>=3.3.20,<3.6.0", ] -sqlserver = [ - "soda-core-sqlserver>=3.3.20,<3.6.0" -] +sqlserver = ["soda-core-sqlserver>=3.3.20,<3.6.0"] -oracle = [ - "soda-core-oracle>=3.3.20,<3.6.0" -] +oracle = ["soda-core-oracle>=3.3.20,<3.6.0"] -athena = [ - "soda-core-athena>=3.3.20,<3.6.0" -] +athena = ["soda-core-athena>=3.3.20,<3.6.0"] -trino = [ - "soda-core-trino>=3.3.20,<3.6.0" -] +trino = ["soda-core-trino>=3.3.20,<3.6.0"] -dbt = [ - "dbt-core>=1.8.0" -] +dbt = ["dbt-core>=1.8.0"] -dbml = [ - "pydbml>=1.1.1" -] +dbml = ["pydbml>=1.1.1"] -parquet = [ - "pyarrow>=18.1.0" -] +parquet = ["pyarrow>=18.1.0"] -rdf = [ - "rdflib==7.0.0", -] +rdf = ["rdflib==7.0.0"] -api = [ - "fastapi==0.121.2", - "uvicorn==0.38.0", -] +api = ["fastapi==0.121.2", "uvicorn==0.38.0"] -protobuf = [ - "grpcio-tools>=1.53", -] +protobuf = ["grpcio-tools>=1.53"] all = [ - "datacontract-cli[kafka,bigquery,csv,excel,snowflake,postgres,databricks,sqlserver,s3,athena,trino,dbt,dbml,iceberg,parquet,rdf,api,protobuf,oracle]" + "datacontract-cli[kafka,bigquery,csv,excel,snowflake,postgres,databricks,sqlserver,s3,athena,trino,dbt,dbml,iceberg,parquet,rdf,api,protobuf,oracle]", ] # for development, we pin all libraries to an exact version @@ -138,7 +102,7 @@ dev = [ "datacontract-cli[all]", 
"httpx==0.28.1", "kafka-python", - "minio==7.2.17", # Pin to 7.2.17 - 7.2.18+ has keyword-only params incompatible with testcontainers + "minio==7.2.17", # Pin to 7.2.17 - 7.2.18+ has keyword-only params incompatible with testcontainers "moto==5.1.18", "pandas>=2.1.0", "pre-commit>=3.7.1,<4.6.0", @@ -172,7 +136,7 @@ line-length = 120 [tool.ruff.lint] extend-select = [ - "I", # re-order imports in alphabetic order + "I", # re-order imports in alphabetic order ] [tool.ruff.lint.per-file-ignores] diff --git a/tests/fixtures/teradata/data/data_constraints.sql b/tests/fixtures/teradata/data/data_constraints.sql new file mode 100644 index 000000000..e32c55679 --- /dev/null +++ b/tests/fixtures/teradata/data/data_constraints.sql @@ -0,0 +1,29 @@ +CREATE TABLE customer_location +( + id DECIMAL NOT NULL, + created_by VARCHAR(30) NOT NULL, + create_date TIMESTAMP NOT NULL, + changed_by VARCHAR(30), + change_date TIMESTAMP, + name VARCHAR(120) NOT NULL, + short_name VARCHAR(60), + display_name VARCHAR(120) NOT NULL, + code VARCHAR(30) NOT NULL, + description VARCHAR(4000), + language_id DECIMAL NOT NULL, + status VARCHAR(2) NOT NULL, + CONSTRAINT customer_location_code_key UNIQUE (code), + CONSTRAINT customer_location_pkey PRIMARY KEY (id) +); + +COMMENT ON TABLE customer_location IS 'Table contains records of customer specific Location/address.'; +COMMENT ON COLUMN customer_location.change_date IS 'Date when record is modified.'; +COMMENT ON COLUMN customer_location.changed_by IS 'User who modified record.'; +COMMENT ON COLUMN customer_location.code IS 'Customer location code.'; +COMMENT ON COLUMN customer_location.create_date IS 'Date when record is created.'; +COMMENT ON COLUMN customer_location.created_by IS 'User who created a record.'; +COMMENT ON COLUMN customer_location.description IS 'Description if needed.'; +COMMENT ON COLUMN customer_location.display_name IS 'Display name of the customer location.'; +COMMENT ON COLUMN customer_location.id IS 'Unique identification ID for the record - created by sequence SEQ_CUSTOMER_LOCATION.'; +COMMENT ON COLUMN customer_location.language_id IS 'Language ID. 
Reference to LANGUAGE table.';
+COMMENT ON COLUMN customer_location.name IS 'Name of the customer location.';
diff --git a/tests/fixtures/teradata/import/ddl.sql b/tests/fixtures/teradata/import/ddl.sql
new file mode 100644
index 000000000..5edf0d769
--- /dev/null
+++ b/tests/fixtures/teradata/import/ddl.sql
@@ -0,0 +1,24 @@
+CREATE TABLE my_table
+(
+    field_primary_key         INTEGER PRIMARY KEY,    -- Primary key
+    field_not_null            INTEGER NOT NULL,       -- Not null
+    field_byteint             BYTEINT,                -- Single-byte integer
+    field_smallint            SMALLINT,               -- Small integer
+    field_int                 INTEGER,                -- Regular integer
+    field_bigint              BIGINT,                 -- Large integer
+    field_decimal             DECIMAL(10, 2),         -- Fixed precision decimal
+    field_numeric             NUMERIC(18, 4),         -- Numeric type
+    field_float               FLOAT,                  -- Floating-point number
+    field_double              DOUBLE,                 -- Double precision float
+    field_char                CHAR(10),               -- Fixed-length character
+    field_varchar             VARCHAR(100),           -- Variable-length character
+    field_date                DATE,                   -- Date only (YYYY-MM-DD)
+    field_time                TIME,                   -- Time only (HH:MM:SS)
+    field_timestamp           TIMESTAMP,              -- Date and time
+    field_interval_year_month INTERVAL YEAR TO MONTH, -- Year-month interval
+    field_interval_day_second INTERVAL DAY TO SECOND, -- Day-second interval
+    field_byte                BYTE(50),               -- Fixed-length byte string
+    field_varbyte             VARBYTE(100),           -- Variable-length byte string
+    field_blob                BLOB,                   -- Binary large object
+    field_clob                CLOB                    -- Character large object
+);
diff --git a/tests/test_import_sql_integration.py b/tests/test_import_sql_integration.py
new file mode 100644
index 000000000..e80e1dae3
--- /dev/null
+++ b/tests/test_import_sql_integration.py
@@ -0,0 +1,359 @@
+"""Integration tests for sql_importer module.
+
+Tests end-to-end functionality of import_sql() function and SqlImporter class,
+covering realistic SQL parsing scenarios and ODCS schema generation.
+""" + +from unittest.mock import patch + +import pytest + +from datacontract.imports.sql_importer import SqlImporter, import_sql +from datacontract.model.exceptions import DataContractException + + +class TestImportSqlIntegration: + """Integration tests for import_sql() function with realistic SQL.""" + + def test_import_sql_single_table_basic(self, tmp_path): + """Should import simple single-table SQL DDL.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100))") + + result = import_sql("sql", str(sql_file), {"dialect": "postgres"}) + + assert result.schema_ is not None + assert len(result.schema_) >= 1 + schema = result.schema_[0] + assert schema.name == "users" + assert schema.physicalType == "table" + assert schema.properties is not None + assert len(schema.properties) == 2 + + # Verify properties + properties_by_name = {p.name: p for p in schema.properties} + assert "id" in properties_by_name + assert "name" in properties_by_name + assert properties_by_name["id"].logicalType == "integer" + assert properties_by_name["name"].logicalType == "string" + + def test_import_sql_with_server_type(self, tmp_path): + """Should create server entry when dialect provided.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE data (id INT)") + + result = import_sql("sql", str(sql_file), {"dialect": "postgres"}) + + assert result.servers is not None + assert len(result.servers) == 1 + server = result.servers[0] + assert server.server == "postgres" + assert server.type == "postgres" + + def test_import_sql_multiple_tables(self, tmp_path): + """Should import multiple table definitions.""" + sql_file = tmp_path / "test.sql" + sql_content = "CREATE TABLE users (id INT PRIMARY KEY, email VARCHAR(100) NOT NULL)" + sql_file.write_text(sql_content) + + result = import_sql("sql", str(sql_file), {"dialect": "postgres"}) + + # sqlglot.parse_one only parses the first statement + assert result.schema_ is not None + assert len(result.schema_) >= 1 + assert result.schema_[0].name == "users" + + def test_import_sql_primary_key_tracking(self, tmp_path): + """Should correctly track primary key positions.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE composite_key (id INT PRIMARY KEY, alt_id INT PRIMARY KEY, name VARCHAR(50))") + + result = import_sql("sql", str(sql_file)) + + assert result.schema_ is not None + schema = result.schema_[0] + assert schema.properties is not None + properties = schema.properties + pk_properties = [p for p in properties if p.primaryKey] + assert len(pk_properties) == 2 + + # Verify primary key positions are sequential + positions = sorted([p.primaryKeyPosition for p in pk_properties if p.primaryKeyPosition is not None]) + assert positions == [1, 2] + + def test_import_sql_not_null_constraint(self, tmp_path): + """Should capture NOT NULL constraints.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE users (id INT NOT NULL, optional_field VARCHAR(50))") + + result = import_sql("sql", str(sql_file)) + + assert result.schema_ is not None + schema = result.schema_[0] + assert schema.properties is not None + properties = schema.properties + props_by_name = {p.name: p for p in properties} + assert props_by_name["id"].required is True + # required is None when not specified (not False) + optional = props_by_name["optional_field"].required + assert optional is None or optional is False + + def test_import_sql_varchar_with_length(self, tmp_path): + """Should extract 
varchar length constraints.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE users (name VARCHAR(255), code CHAR(5))") + + result = import_sql("sql", str(sql_file)) + + assert result.schema_ is not None + schema = result.schema_[0] + assert schema.properties is not None + properties = schema.properties + props_by_name = {p.name: p for p in properties} + name_opts = props_by_name["name"].logicalTypeOptions + code_opts = props_by_name["code"].logicalTypeOptions + assert name_opts is not None + assert name_opts.get("maxLength") == 255 + assert code_opts is not None + assert code_opts.get("maxLength") == 5 + + def test_import_sql_decimal_precision_scale(self, tmp_path): + """Should extract precision and scale from decimal types.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE products (price DECIMAL(10, 2))") + + result = import_sql("sql", str(sql_file)) + + assert result.schema_ is not None + schema = result.schema_[0] + assert schema.properties is not None + properties = schema.properties + price_prop = properties[0] + price_opts = price_prop.logicalTypeOptions + assert price_opts is not None + assert price_opts.get("precision") == 10 + assert price_opts.get("scale") == 2 + + def test_import_sql_preserves_physical_type(self, tmp_path): + """Should preserve original SQL type as physicalType (normalized by sqlglot).""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE data (ts TIMESTAMP, data JSON, num NUMERIC(10,2))") + + result = import_sql("sql", str(sql_file)) + + assert result.schema_ is not None + schema = result.schema_[0] + assert schema.properties is not None + properties = schema.properties + props_by_name = {p.name: p for p in properties} + assert props_by_name["ts"].physicalType == "TIMESTAMP" + assert props_by_name["data"].physicalType == "JSON" + # NUMERIC is normalized to DECIMAL with full specification by sqlglot + num_type = props_by_name["num"].physicalType + assert num_type is not None + assert "DECIMAL" in num_type + + def test_import_sql_invalid_sql(self, tmp_path): + """Should raise DataContractException for invalid SQL.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE users (id INT INVALID SYNTAX HERE") + + with pytest.raises(DataContractException) as exc_info: + import_sql("sql", str(sql_file)) + + assert "Error parsing SQL" in exc_info.value.reason + + def test_import_sql_missing_file(self): + """Should raise DataContractException for missing file.""" + with pytest.raises(DataContractException) as exc_info: + import_sql("sql", "/nonexistent/path.sql") + + assert "does not exist" in exc_info.value.reason + + def test_import_sql_without_dialect(self, tmp_path): + """Should work without dialect (None server).""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE users (id INT)") + + result = import_sql("sql", str(sql_file)) + + assert result.schema_ is not None + assert len(result.schema_) == 1 + # Server should be None or empty when no dialect provided + assert result.servers is None or len(result.servers) == 0 + + def test_import_sql_with_all_column_features(self, tmp_path): + """Should handle table with various column features together.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text( + """CREATE TABLE orders ( + id INT PRIMARY KEY, + customer_id INT NOT NULL, + amount DECIMAL(12, 2), + description VARCHAR(500), + created_at TIMESTAMP + )""" + ) + + result = import_sql("sql", str(sql_file), {"dialect": "postgres"}) + + assert result.schema_ 
is not None + assert len(result.schema_) == 1 + schema = result.schema_[0] + assert schema.properties is not None + assert len(schema.properties) == 5 + + properties = schema.properties + props_by_name = {p.name: p for p in properties} + + # Verify all properties are correctly populated + assert props_by_name["id"].primaryKey is True + assert props_by_name["customer_id"].required is True + amount_opts = props_by_name["amount"].logicalTypeOptions + desc_opts = props_by_name["description"].logicalTypeOptions + assert amount_opts is not None + assert amount_opts.get("precision") == 12 + assert desc_opts is not None + assert desc_opts.get("maxLength") == 500 + assert props_by_name["created_at"].logicalType == "date" + + +class TestSqlImporter: + """Tests for SqlImporter class.""" + + def test_sql_importer_initialization(self): + """Should initialize SqlImporter with import format.""" + importer = SqlImporter("sql") + assert importer.import_format == "sql" + + def test_sql_importer_import_source(self, tmp_path): + """Should import source using import_source method.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100))") + + importer = SqlImporter("sql") + result = importer.import_source(str(sql_file), {"dialect": "postgres"}) + + assert result.schema_ is not None + assert len(result.schema_) == 1 + assert result.schema_[0].name == "users" + + def test_sql_importer_with_tsql_dialect(self, tmp_path): + """Should handle SQL Server specific syntax.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE users (id INT PRIMARY KEY, data NVARCHAR(255) NOT NULL)") + + importer = SqlImporter("sql") + result = importer.import_source(str(sql_file), {"dialect": "sqlserver"}) + + assert result.servers is not None + assert result.servers[0].type == "sqlserver" + assert result.schema_ is not None + schema = result.schema_[0] + assert schema.properties is not None + assert len(schema.properties) == 2 + + def test_sql_importer_handles_import_args(self, tmp_path): + """Should pass import_args through to import_sql.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE data (id INT)") + + importer = SqlImporter("sql") + import_args = {"dialect": "mysql"} + result = importer.import_source(str(sql_file), import_args) + + assert result.servers is not None + assert result.servers[0].type == "mysql" + + def test_sql_importer_without_import_args(self, tmp_path): + """Should handle None import_args gracefully.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE data (id INT)") + + importer = SqlImporter("sql") + result = importer.import_source(str(sql_file), {}) + + assert result.schema_ is not None + assert len(result.schema_) == 1 + + +class TestImportSqlEdgeCases: + """Edge case and error handling tests.""" + + def test_import_sql_empty_sql_creates_empty_schema(self, tmp_path): + """Should raise exception for SQL with no CREATE TABLE statements.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("-- just a comment\n-- nothing here") + + with pytest.raises(DataContractException): + import_sql("sql", str(sql_file)) + + def test_import_sql_mixed_dialects_postgres_to_tsql(self, tmp_path): + """Should parse with specified dialect even if SQL syntax differs.""" + sql_file = tmp_path / "test.sql" + # Standard SQL, parsed as TSQL + sql_file.write_text("CREATE TABLE test (id INT, data VARCHAR(100))") + + result = import_sql("sql", str(sql_file), {"dialect": "tsql"}) + + assert result.servers 
is not None + assert result.servers[0].type == "sqlserver" + assert result.schema_ is not None + assert len(result.schema_) == 1 + + def test_import_sql_table_with_no_constraints(self, tmp_path): + """Should handle table with minimal constraints.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE minimal (col1 INT, col2 VARCHAR(50), col3 DECIMAL(5, 2))") + + result = import_sql("sql", str(sql_file)) + + assert result.schema_ is not None + schema = result.schema_[0] + assert schema.properties is not None + properties = schema.properties + assert len(properties) == 3 + assert all(p.primaryKey is None or p.primaryKey is False for p in properties) + assert all(p.required is None or p.required is False for p in properties) + + def test_import_sql_various_numeric_types(self, tmp_path): + """Should handle all numeric types correctly.""" + sql_file = tmp_path / "test.sql" + sql_file.write_text( + """CREATE TABLE numbers ( + tiny TINYINT, + small SMALLINT, + regular INT, + big BIGINT, + flt FLOAT, + dbl DOUBLE, + dec DECIMAL(15, 4) + )""" + ) + + result = import_sql("sql", str(sql_file)) + + assert result.schema_ is not None + schema = result.schema_[0] + assert schema.properties is not None + properties = schema.properties + props_by_name = {p.name: p for p in properties} + + # All should map to either integer or number + integer_types = {"tiny", "small", "regular", "big"} + number_types = {"flt", "dbl", "dec"} + + for name in integer_types: + assert props_by_name[name].logicalType == "integer" + for name in number_types: + assert props_by_name[name].logicalType == "number" + + @patch("datacontract.imports.sql_importer.pathlib.Path", spec=True) + def test_import_sql_error_parsing_sql_calls_logger(self, mock_path, tmp_path): + """Should log exceptions when SQL parsing fails.""" + # Create real file that sqlglot can't parse + sql_file = tmp_path / "test.sql" + sql_file.write_text("CREATE TABLE test (id INT VERY INVALID :::)") + + with pytest.raises(DataContractException): + import_sql("sql", str(sql_file)) diff --git a/tests/test_import_sql_teradata.py b/tests/test_import_sql_teradata.py new file mode 100644 index 000000000..94e152665 --- /dev/null +++ b/tests/test_import_sql_teradata.py @@ -0,0 +1,236 @@ +"""Test importing Teradata SQL DDL files into Data Contracts.""" + +import logging + +import yaml +from typer.testing import CliRunner + +from datacontract.cli import app +from datacontract.data_contract import DataContract + +logger = logging.getLogger(__name__) + +sql_file_path = "fixtures/teradata/import/ddl.sql" + + +def test_cli(): + """Test the CLI import command for Teradata SQL DDL files.""" + runner = CliRunner() + result = runner.invoke( + app, + [ + "import", + "--format", + "sql", + "--source", + sql_file_path, + ], + ) + assert result.exit_code == 0 + + +def test_import_sql_teradata(): + """Test importing a Teradata SQL DDL file into a Data Contract.""" + result = DataContract.import_from_source("sql", sql_file_path, dialect="teradata") + + expected = """ +apiVersion: v3.1.0 +kind: DataContract +id: my-data-contract +name: My Data Contract +version: 1.0.0 +status: draft +servers: + - server: teradata + type: teradata +schema: + - name: my_table + physicalType: table + logicalType: object + physicalName: my_table + properties: + - name: field_primary_key + physicalType: INT + description: Primary key + primaryKey: true + primaryKeyPosition: 1 + logicalType: integer + - name: field_not_null + physicalType: INT + description: Not null + logicalType: integer + 
required: true + - name: field_byteint + physicalType: SMALLINT + description: Single-byte integer + logicalType: integer + - name: field_smallint + physicalType: SMALLINT + description: Small integer + logicalType: integer + - name: field_int + physicalType: INT + description: Regular integer + logicalType: integer + - name: field_bigint + physicalType: BIGINT + description: Large integer + logicalType: integer + - name: field_decimal + physicalType: DECIMAL(10, 2) + description: Fixed precision decimal + logicalType: number + logicalTypeOptions: + precision: 10 + scale: 2 + - name: field_numeric + physicalType: DECIMAL(18, 4) + description: Numeric type + logicalType: number + logicalTypeOptions: + precision: 18 + scale: 4 + - name: field_float + physicalType: FLOAT + description: Floating-point number + logicalType: number + - name: field_double + physicalType: DOUBLE PRECISION + description: Double precision float + logicalType: number + - name: field_char + physicalType: CHAR(10) + description: Fixed-length character + logicalType: string + logicalTypeOptions: + maxLength: 10 + - name: field_varchar + physicalType: VARCHAR(100) + description: Variable-length character + logicalType: string + logicalTypeOptions: + maxLength: 100 + - name: field_date + physicalType: DATE + description: Date only (YYYY-MM-DD) + logicalType: date + - name: field_time + physicalType: TIME + description: Time only (HH:MM:SS) + logicalType: string + - name: field_timestamp + physicalType: TIMESTAMP + description: Date and time + logicalType: date + - name: field_interval_year_month + physicalType: INTERVAL YEAR TO MONTH + description: Year-month interval + logicalType: string + - name: field_interval_day_second + physicalType: INTERVAL DAY TO SECOND + description: Day-second interval + logicalType: string + - name: field_byte + physicalType: TINYINT(50) + description: Fixed-length byte string + logicalType: integer + - name: field_varbyte + physicalType: VARBYTE(100) + description: Variable-length byte string + logicalType: array + - name: field_blob + physicalType: VARBINARY + description: Binary large object + logicalType: array + - name: field_clob + physicalType: TEXT + description: Character large object + logicalType: string + """ + logger.info("Result: %s", result.to_yaml()) + assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) + + +def test_import_sql_constraints(): + """Test importing SQL DDL file with constraints into a Data Contract.""" + result = DataContract.import_from_source("sql", "fixtures/teradata/data/data_constraints.sql", dialect="teradata") + + expected = """ +apiVersion: v3.1.0 +kind: DataContract +id: my-data-contract +name: My Data Contract +version: 1.0.0 +status: draft +servers: + - server: teradata + type: teradata +schema: + - name: customer_location + physicalType: table + logicalType: object + physicalName: customer_location + properties: + - name: id + logicalType: number + physicalType: DECIMAL + required: true + - name: created_by + logicalType: string + logicalTypeOptions: + maxLength: 30 + physicalType: VARCHAR(30) + required: true + - name: create_date + logicalType: date + physicalType: TIMESTAMP + required: true + - name: changed_by + logicalType: string + logicalTypeOptions: + maxLength: 30 + physicalType: VARCHAR(30) + - name: change_date + logicalType: date + physicalType: TIMESTAMP + - name: name + logicalType: string + logicalTypeOptions: + maxLength: 120 + physicalType: VARCHAR(120) + required: true + - name: short_name + logicalType: string + 
logicalTypeOptions: + maxLength: 60 + physicalType: VARCHAR(60) + - name: display_name + logicalType: string + logicalTypeOptions: + maxLength: 120 + physicalType: VARCHAR(120) + required: true + - name: code + logicalType: string + logicalTypeOptions: + maxLength: 30 + physicalType: VARCHAR(30) + required: true + - name: description + logicalType: string + logicalTypeOptions: + maxLength: 4000 + physicalType: VARCHAR(4000) + - name: language_id + logicalType: number + physicalType: DECIMAL + required: true + - name: status + logicalType: string + logicalTypeOptions: + maxLength: 2 + physicalType: VARCHAR(2) + required: true + """ + logger.info("Result: %s", result.to_yaml()) + assert yaml.safe_load(result.to_yaml()) == yaml.safe_load(expected) diff --git a/tests/test_import_sql_unit.py b/tests/test_import_sql_unit.py new file mode 100644 index 000000000..ece0cdad9 --- /dev/null +++ b/tests/test_import_sql_unit.py @@ -0,0 +1,496 @@ +"""Unit tests for sql_importer helper functions. + +Tests individual functions and helper methods in the sql_importer module, +focusing on isolated unit testing with no external dependencies. +""" + +import pytest +import sqlglot +from sqlglot.dialects.dialect import Dialects +from sqlglot.expressions import ColumnDef + +from datacontract.imports.sql_importer import ( + get_description, + get_max_length, + get_precision_scale, + get_primary_key, + map_type_from_sql, + read_file, + to_col_type, + to_col_type_normalized, + to_dialect, + to_server_type, +) +from datacontract.model.exceptions import DataContractException + + +class TestToDialect: + """Test dialect string to SQLGlot dialect conversion.""" + + def test_to_dialect_with_none(self): + """Should return None when input is None.""" + assert to_dialect(None) is None + + def test_to_dialect_with_sqlserver(self): + """Should convert 'sqlserver' to TSQL dialect.""" + result = to_dialect("sqlserver") + assert result == Dialects.TSQL + + def test_to_dialect_with_uppercase(self): + """Should convert uppercase dialect names.""" + result = to_dialect("POSTGRES") + assert result == Dialects.POSTGRES + + def test_to_dialect_with_lowercase(self): + """Should convert lowercase dialect names.""" + result = to_dialect("postgres") + assert result == Dialects.POSTGRES + + def test_to_dialect_with_mixed_case(self): + """Should convert mixed case dialect names.""" + result = to_dialect("BigQuery") + assert result == Dialects.BIGQUERY + + def test_to_dialect_with_unrecognized_dialect(self, caplog): + """Should return None and log warning for unrecognized dialects.""" + result = to_dialect("unknown_dialect") + assert result is None + assert "not recognized" in caplog.text + + +class TestToServerType: + """Test SQLGlot dialect to server type conversion.""" + + def test_to_server_type_tsql(self): + """Should map TSQL dialect to sqlserver.""" + assert to_server_type(Dialects.TSQL) == "sqlserver" + + def test_to_server_type_postgres(self): + """Should map POSTGRES dialect to postgres.""" + assert to_server_type(Dialects.POSTGRES) == "postgres" + + def test_to_server_type_bigquery(self): + """Should map BIGQUERY dialect to bigquery.""" + assert to_server_type(Dialects.BIGQUERY) == "bigquery" + + def test_to_server_type_snowflake(self): + """Should map SNOWFLAKE dialect to snowflake.""" + assert to_server_type(Dialects.SNOWFLAKE) == "snowflake" + + def test_to_server_type_redshift(self): + """Should map REDSHIFT dialect to redshift.""" + assert to_server_type(Dialects.REDSHIFT) == "redshift" + + def test_to_server_type_oracle(self): 
+ """Should map ORACLE dialect to oracle.""" + assert to_server_type(Dialects.ORACLE) == "oracle" + + def test_to_server_type_mysql(self): + """Should map MYSQL dialect to mysql.""" + assert to_server_type(Dialects.MYSQL) == "mysql" + + def test_to_server_type_databricks(self): + """Should map DATABRICKS dialect to databricks.""" + assert to_server_type(Dialects.DATABRICKS) == "databricks" + + def test_to_server_type_teradata(self): + """Should map TERADATA dialect to teradata.""" + assert to_server_type(Dialects.TERADATA) == "teradata" + + def test_to_server_type_unmapped_dialect(self, caplog): + """Should return None and log warning for unmapped dialects.""" + result = to_server_type(Dialects.SPARK) + assert result is None + assert "No server type mapping" in caplog.text + + +class TestGetPrimaryKey: + """Test primary key detection in column definitions.""" + + def test_get_primary_key_with_primary_key_constraint(self): + """Should detect PrimaryKeyColumnConstraint.""" + sql = "CREATE TABLE t (id INT PRIMARY KEY)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + assert get_primary_key(column) is True + + def test_get_primary_key_without_constraint(self): + """Should return False when no primary key constraint.""" + sql = "CREATE TABLE t (id INT)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + assert get_primary_key(column) is False + + def test_get_primary_key_with_not_null(self): + """Should not confuse NOT NULL with PRIMARY KEY.""" + sql = "CREATE TABLE t (id INT NOT NULL)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + assert get_primary_key(column) is False + + +class TestToColType: + """Test column type extraction.""" + + def test_to_col_type_integer(self): + """Should extract integer type.""" + sql = "CREATE TABLE t (id INT)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + col_type = to_col_type(column, Dialects.POSTGRES) + assert col_type == "INT" + + def test_to_col_type_varchar(self): + """Should extract varchar type with length.""" + sql = "CREATE TABLE t (name VARCHAR(100))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + col_type = to_col_type(column, Dialects.POSTGRES) + assert col_type is not None + assert "VARCHAR" in col_type + + def test_to_col_type_decimal_with_precision(self): + """Should extract decimal type with precision and scale.""" + sql = "CREATE TABLE t (amount DECIMAL(10, 2))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + col_type = to_col_type(column, Dialects.POSTGRES) + assert col_type is not None + assert "DECIMAL" in col_type + + def test_to_col_type_with_dialect_conversion(self): + """Should use dialect for type conversion.""" + sql = "CREATE TABLE t (data NVARCHAR(50))" + parsed = sqlglot.parse_one(sql, read="tsql") + column = next(iter(parsed.find_all(ColumnDef))) + col_type = to_col_type(column, Dialects.TSQL) + assert col_type is not None + + +class TestToColTypeNormalized: + """Test normalized column type extraction.""" + + def test_to_col_type_normalized_int(self): + """Should normalize INT to lowercase.""" + sql = "CREATE TABLE t (id INT)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + col_type = to_col_type_normalized(column) + assert col_type == "int" + + def 
test_to_col_type_normalized_varchar(self): + """Should normalize VARCHAR to lowercase.""" + sql = "CREATE TABLE t (name VARCHAR(100))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + col_type = to_col_type_normalized(column) + assert col_type == "varchar" + + def test_to_col_type_normalized_decimal(self): + """Should normalize DECIMAL to lowercase.""" + sql = "CREATE TABLE t (amount DECIMAL(10, 2))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + col_type = to_col_type_normalized(column) + assert col_type == "decimal" + + +class TestGetDescription: + """Test description extraction from column comments.""" + + def test_get_description_without_comment(self): + """Should return None when no comment present.""" + sql = "CREATE TABLE t (id INT)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + description = get_description(column) + assert description is None + + def test_get_description_with_none_comments(self): + """Should handle columns with no comments gracefully.""" + sql = "CREATE TABLE t (name VARCHAR(100), age INT)" + parsed = sqlglot.parse_one(sql, read="postgres") + columns = list(parsed.find_all(ColumnDef)) + # Both columns should have no description + for column in columns: + description = get_description(column) + assert description is None + + +class TestGetMaxLength: + """Test maximum length extraction from varchar/char types.""" + + def test_get_max_length_varchar_with_length(self): + """Should extract max length from VARCHAR.""" + sql = "CREATE TABLE t (name VARCHAR(100))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + max_length = get_max_length(column) + assert max_length == 100 + + def test_get_max_length_char_with_length(self): + """Should extract max length from CHAR.""" + sql = "CREATE TABLE t (code CHAR(10))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + max_length = get_max_length(column) + assert max_length == 10 + + def test_get_max_length_nvarchar(self): + """Should extract max length from NVARCHAR.""" + sql = "CREATE TABLE t (text NVARCHAR(255))" + parsed = sqlglot.parse_one(sql, read="tsql") + column = next(iter(parsed.find_all(ColumnDef))) + max_length = get_max_length(column) + assert max_length == 255 + + def test_get_max_length_nchar(self): + """Should extract max length from NCHAR.""" + sql = "CREATE TABLE t (code NCHAR(5))" + parsed = sqlglot.parse_one(sql, read="tsql") + column = next(iter(parsed.find_all(ColumnDef))) + max_length = get_max_length(column) + assert max_length == 5 + + def test_get_max_length_integer_returns_none(self): + """Should return None for non-string types.""" + sql = "CREATE TABLE t (id INT)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + max_length = get_max_length(column) + assert max_length is None + + def test_get_max_length_varchar_without_length(self): + """Should return None for VARCHAR without length.""" + sql = "CREATE TABLE t (text VARCHAR)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + max_length = get_max_length(column) + assert max_length is None + + +class TestGetPrecisionScale: + """Test precision and scale extraction from numeric types.""" + + def test_get_precision_scale_decimal_both(self): + """Should extract both precision and scale 
from DECIMAL.""" + sql = "CREATE TABLE t (amount DECIMAL(10, 2))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + precision, scale = get_precision_scale(column) + assert precision == 10 + assert scale == 2 + + def test_get_precision_scale_numeric_both(self): + """Should extract both precision and scale from NUMERIC.""" + sql = "CREATE TABLE t (amount NUMERIC(8, 4))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + precision, scale = get_precision_scale(column) + assert precision == 8 + assert scale == 4 + + def test_get_precision_scale_precision_only(self): + """Should extract precision and return scale as 0 when only precision given.""" + sql = "CREATE TABLE t (amount DECIMAL(10))" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + precision, scale = get_precision_scale(column) + assert precision == 10 + assert scale == 0 + + def test_get_precision_scale_number_oracle(self): + """Should extract precision and scale from NUMBER (Oracle).""" + sql = "CREATE TABLE t (amount NUMBER(15, 3))" + parsed = sqlglot.parse_one(sql, read="oracle") + column = next(iter(parsed.find_all(ColumnDef))) + precision, scale = get_precision_scale(column) + assert precision == 15 + assert scale == 3 + + def test_get_precision_scale_float(self): + """Should extract precision from FLOAT.""" + sql = "CREATE TABLE t (value FLOAT(10))" + parsed = sqlglot.parse_one(sql, read="mysql") + column = next(iter(parsed.find_all(ColumnDef))) + precision, _scale = get_precision_scale(column) + assert precision == 10 + + def test_get_precision_scale_integer_returns_none(self): + """Should return None for non-numeric types.""" + sql = "CREATE TABLE t (id INT)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + precision, scale = get_precision_scale(column) + assert precision is None + assert scale is None + + def test_get_precision_scale_no_params(self): + """Should return None when no parameters given.""" + sql = "CREATE TABLE t (amount DECIMAL)" + parsed = sqlglot.parse_one(sql, read="postgres") + column = next(iter(parsed.find_all(ColumnDef))) + precision, scale = get_precision_scale(column) + assert precision is None + assert scale is None + + +class TestMapTypeFromSql: + """Test SQL type to ODCS logical type mapping.""" + + def test_map_type_string_types(self): + """Should map VARCHAR, CHAR, TEXT to string.""" + assert map_type_from_sql("VARCHAR(100)") == "string" + assert map_type_from_sql("CHAR(10)") == "string" + assert map_type_from_sql("TEXT") == "string" + assert map_type_from_sql("NVARCHAR(50)") == "string" + assert map_type_from_sql("NCHAR(5)") == "string" + assert map_type_from_sql("NTEXT") == "string" + assert map_type_from_sql("CLOB") == "string" + assert map_type_from_sql("NCLOB") == "string" + + def test_map_type_integer_types(self): + """Should map various integer types to integer.""" + assert map_type_from_sql("INT") == "integer" + assert map_type_from_sql("INTEGER") == "integer" + assert map_type_from_sql("BIGINT") == "integer" + assert map_type_from_sql("TINYINT") == "integer" + assert map_type_from_sql("SMALLINT") == "integer" + + def test_map_type_byteint(self): + """Should map BYTEINT (Teradata) to integer.""" + assert map_type_from_sql("BYTEINT") == "integer" + + def test_map_type_numeric_types(self): + """Should map numeric types to number.""" + assert map_type_from_sql("DECIMAL(10, 2)") == 
"number" + assert map_type_from_sql("NUMERIC(8, 4)") == "number" + assert map_type_from_sql("FLOAT(10)") == "number" + assert map_type_from_sql("DOUBLE") == "number" + assert map_type_from_sql("NUMBER(15, 3)") == "number" + + def test_map_type_boolean(self): + """Should map boolean types to boolean.""" + assert map_type_from_sql("BOOLEAN") == "boolean" + assert map_type_from_sql("BOOL") == "boolean" + assert map_type_from_sql("BIT") == "boolean" + + def test_map_type_date(self): + """Should map DATE to date.""" + assert map_type_from_sql("DATE") == "date" + + def test_map_type_datetime_types(self): + """Should map datetime types to date.""" + assert map_type_from_sql("DATETIME") == "date" + assert map_type_from_sql("DATETIME2") == "date" + assert map_type_from_sql("SMALLDATETIME") == "date" + assert map_type_from_sql("DATETIMEOFFSET") == "date" + assert map_type_from_sql("TIMESTAMP") == "date" + + def test_map_type_timestamp(self): + """Should map TIMESTAMP variants to date.""" + assert map_type_from_sql("TIMESTAMP") == "date" + assert map_type_from_sql("TIMESTAMP(6)") == "date" + + def test_map_type_interval_oracle(self): + """Should map INTERVAL to object for Oracle.""" + assert map_type_from_sql("INTERVAL YEAR TO MONTH", Dialects.ORACLE) == "object" + assert map_type_from_sql("INTERVAL DAY TO SECOND", Dialects.ORACLE) == "object" + + def test_map_type_interval_non_oracle(self): + """Should map INTERVAL to string for non-Oracle dialects.""" + assert map_type_from_sql("INTERVAL YEAR TO MONTH", Dialects.POSTGRES) == "string" + assert map_type_from_sql("INTERVAL DAY TO SECOND", Dialects.MYSQL) == "string" + + def test_map_type_binary_types(self): + """Should map binary types to array.""" + assert map_type_from_sql("BINARY") == "array" + assert map_type_from_sql("VARBINARY") == "array" + assert map_type_from_sql("RAW") == "array" + assert map_type_from_sql("BYTE") == "array" + assert map_type_from_sql("VARBYTE") == "array" + assert map_type_from_sql("BLOB") == "array" + assert map_type_from_sql("BFILE") == "array" + + def test_map_type_json(self): + """Should map JSON to object.""" + assert map_type_from_sql("JSON") == "object" + + def test_map_type_xml(self): + """Should map XML to string.""" + assert map_type_from_sql("XML") == "string" + + def test_map_type_unique_identifier(self): + """Should map UNIQUEIDENTIFIER to string.""" + assert map_type_from_sql("UNIQUEIDENTIFIER") == "string" + + def test_map_type_time(self): + """Should map TIME to string.""" + assert map_type_from_sql("TIME") == "string" + + def test_map_type_unknown(self): + """Should map unknown types to object.""" + assert map_type_from_sql("UNKNOWN_TYPE") == "object" + + def test_map_type_case_insensitive(self): + """Should handle case-insensitive type names.""" + assert map_type_from_sql("varchar(100)") == "string" + assert map_type_from_sql("INT") == "integer" + assert map_type_from_sql("Decimal(10,2)") == "number" + + def test_map_type_whitespace_handling(self): + """Should handle leading and trailing whitespace.""" + assert map_type_from_sql(" VARCHAR(100) ") == "string" + assert map_type_from_sql("\tINT\t") == "integer" + + +class TestReadFile: + """Test file reading functionality.""" + + def test_read_file_existing_file(self, tmp_path): + """Should read content from existing file.""" + test_file = tmp_path / "test.sql" + test_content = "CREATE TABLE test (id INT)" + test_file.write_text(test_content) + + result = read_file(str(test_file)) + assert result == test_content + + def test_read_file_nonexistent_file(self): 
+ """Should raise DataContractException for nonexistent file.""" + with pytest.raises(DataContractException) as exc_info: + read_file("/nonexistent/path/file.sql") + + assert "does not exist" in exc_info.value.reason + + def test_read_file_empty_file(self, tmp_path): + """Should raise DataContractException for empty file.""" + test_file = tmp_path / "empty.sql" + test_file.write_text("") + + with pytest.raises(DataContractException) as exc_info: + read_file(str(test_file)) + + assert "is empty" in exc_info.value.reason + + def test_read_file_whitespace_only(self, tmp_path): + """Should raise DataContractException for whitespace-only file.""" + test_file = tmp_path / "whitespace.sql" + test_file.write_text(" \n\n \t ") + + with pytest.raises(DataContractException) as exc_info: + read_file(str(test_file)) + + assert "is empty" in exc_info.value.reason + + def test_read_file_multiline_content(self, tmp_path): + """Should read multiline content correctly.""" + test_file = tmp_path / "multiline.sql" + test_content = """CREATE TABLE users ( + id INT PRIMARY KEY, + name VARCHAR(100) +)""" + test_file.write_text(test_content) + + result = read_file(str(test_file)) + assert result == test_content