Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e94e6ec
Merge pull request #5 from datacontract/main
dmaresma Jun 11, 2025
657b68d
init. snowflake sql ddl import to datacontract
dmaresma Jun 11, 2025
a224aba
apply ruff check and format
dmaresma Jun 11, 2025
327c21a
align import
dmaresma Jun 11, 2025
234c2fb
add dialect
dmaresma Jun 11, 2025
5d412fd
sqlglot ${} token bypass and waiting for NOORDER ORDER AUTOINCREMENT …
dmaresma Jun 13, 2025
76d53b8
fix regression on sql server side (no formal or declarative comments)
dmaresma Jun 13, 2025
020d879
type variant not allowed in lint DataContract(data_contract_str=expect…
dmaresma Jun 14, 2025
e2ee1e8
remove simple-ddl-parser dependency
dmaresma Jun 14, 2025
ab60f5c
Merge branch 'main' into feat/snowflake_ddl_sql_import
dmaresma Jun 29, 2025
dd2a399
fix error message
dmaresma Jul 10, 2025
6d2a8df
Merge branch 'feat/snowflake_ddl_sql_import' of https://github.com/dm…
dmaresma Jul 10, 2025
d3759c9
fix specification version in test
dmaresma Jul 10, 2025
d29d770
refactor get_model_from_parsed: add table description and table tags
dmaresma Jul 10, 2025
cff64f8
fix format issue
dmaresma Jul 10, 2025
c6bf517
Merge branch 'main' into feat/snowflake_ddl_sql_import
dmaresma Jul 10, 2025
f569c9f
Merge branch 'main' into feat/snowflake_ddl_sql_import
dmaresma Jul 11, 2025
1186bb3
add script token remover function
dmaresma Jul 27, 2025
eb718c5
Merge branch 'main' into feat/snowflake_ddl_sql_import
dmaresma Jul 28, 2025
593358c
Merge branch 'main' into feat/snowflake_ddl_sql_import
dmaresma Aug 5, 2025
1b44135
add money datatype #751
dmaresma Aug 25, 2025
29f371e
ignore jinja
dmaresma Aug 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 133 additions & 52 deletions datacontract/imports/sql_importer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import logging
import os
import re

import sqlglot
from sqlglot.dialects.dialect import Dialects

from datacontract.imports.importer import Importer
from datacontract.model.data_contract_specification import DataContractSpecification, Field, Model, Server
from datacontract.model.data_contract_specification import (
DataContractSpecification,
Field,
Model,
Server,
)
from datacontract.model.exceptions import DataContractException
from datacontract.model.run import ResultEnum

Expand All @@ -18,16 +24,28 @@ def import_source(


def import_sql(
data_contract_specification: DataContractSpecification, format: str, source: str, import_args: dict = None
data_contract_specification: DataContractSpecification,
format: str,
source: str,
import_args: dict = None,
) -> DataContractSpecification:
dialect = to_dialect(import_args)

server_type: str | None = to_server_type(source, dialect)
if server_type is not None:
data_contract_specification.servers[server_type] = Server(type=server_type)

sql = read_file(source)

dialect = to_dialect(import_args)
parsed = None

try:
parsed = sqlglot.parse_one(sql=sql, read=dialect)
parsed = sqlglot.parse_one(sql=sql, read=dialect.lower())

tables = parsed.find_all(sqlglot.expressions.Table)

except Exception as e:
logging.error(f"Error parsing SQL: {str(e)}")
logging.error(f"Error sqlglot SQL: {str(e)}")
raise DataContractException(
type="import",
name=f"Reading source from {source}",
Expand All @@ -36,50 +54,90 @@ def import_sql(
result=ResultEnum.error,
)

server_type: str | None = to_server_type(source, dialect)
if server_type is not None:
data_contract_specification.servers[server_type] = Server(type=server_type)

tables = parsed.find_all(sqlglot.expressions.Table)

for table in tables:
if data_contract_specification.models is None:
data_contract_specification.models = {}

table_name = table.this.name

fields = {}
for column in parsed.find_all(sqlglot.exp.ColumnDef):
if column.parent.this.name != table_name:
continue

field = Field()
col_name = column.this.name
col_type = to_col_type(column, dialect)
field.type = map_type_from_sql(col_type)
col_description = get_description(column)
field.description = col_description
field.maxLength = get_max_length(column)
precision, scale = get_precision_scale(column)
field.precision = precision
field.scale = scale
field.primaryKey = get_primary_key(column)
field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None
physical_type_key = to_physical_type_key(dialect)
field.config = {
physical_type_key: col_type,
}

fields[col_name] = field

data_contract_specification.models[table_name] = Model(
type="table",
fields=fields,
data_contract_specification.models[table.this.name] = get_model_from_parsed(
table_name=table.this.name, parsed=parsed, dialect=dialect
)

return data_contract_specification


def get_model_from_parsed(table_name, parsed, dialect) -> Model:
    """Build a data contract Model for one table from a sqlglot parse tree.

    Args:
        table_name: Name of the table whose columns should be collected.
        parsed: The sqlglot expression tree of the parsed DDL statement.
        dialect: The sqlglot dialect the DDL was parsed with; used to render
            column types and to pick the physical-type config key.

    Returns:
        A Model of type "table" carrying the table description, table tags
        and one Field per column of ``table_name``.
    """
    table_description = None
    table_tags = None

    # Table-level COMMENT = '...' property (e.g. Snowflake), if present.
    comment_property = parsed.find(sqlglot.expressions.SchemaCommentProperty)
    if comment_property:
        table_description = comment_property.this.this

    # Table-level tags, e.g. Snowflake WITH TAG (...).
    properties = parsed.find(sqlglot.expressions.Properties)
    if properties:
        tags = properties.find(sqlglot.expressions.Tags)
        if tags:
            # NOTE(review): find() returns only the first Property node, which
            # is then iterated — presumably yielding its child expressions;
            # confirm against sqlglot that multiple tags are not dropped here.
            tag_property = tags.find(sqlglot.expressions.Property)
            table_tags = [str(tag) for tag in tag_property]

    # The config key for the physical type depends only on the dialect,
    # so compute it once instead of once per column.
    physical_type_key = to_physical_type_key(dialect)

    fields = {}
    for column in parsed.find_all(sqlglot.exp.ColumnDef):
        # Scripts may define several tables; keep only this table's columns.
        if column.parent.this.name != table_name:
            continue

        field = Field()
        col_name = column.this.name
        col_type = to_col_type(column, dialect)
        field.type = map_type_from_sql(col_type)
        field.description = get_description(column)
        field.maxLength = get_max_length(column)
        precision, scale = get_precision_scale(column)
        field.precision = precision
        field.scale = scale
        field.primaryKey = get_primary_key(column)
        # True for NOT NULL columns, otherwise left unset (None, not False).
        field.required = column.find(sqlglot.exp.NotNullColumnConstraint) is not None or None
        field.tags = get_tags(column)
        field.config = {
            physical_type_key: col_type,
        }

        fields[col_name] = field

    return Model(
        type="table",
        description=table_description,
        tags=table_tags,
        fields=fields,
    )


def map_physical_type(column, dialect) -> str | None:
autoincrement = ""
if column.get("autoincrement") and dialect == Dialects.SNOWFLAKE:
autoincrement = " AUTOINCREMENT" + " START " + str(column.get("start")) if column.get("start") else ""
autoincrement += " INCREMENT " + str(column.get("increment")) if column.get("increment") else ""
autoincrement += " NOORDER" if not column.get("increment_order") else ""
elif column.get("autoincrement"):
autoincrement = " IDENTITY"

if column.get("size") and isinstance(column.get("size"), tuple):
return (
column.get("type")
+ "("
+ str(column.get("size")[0])
+ ","
+ str(column.get("size")[1])
+ ")"
+ autoincrement
)
elif column.get("size"):
return column.get("type") + "(" + str(column.get("size")) + ")" + autoincrement
else:
return column.get("type") + autoincrement


def get_primary_key(column) -> bool | None:
if column.find(sqlglot.exp.PrimaryKeyColumnConstraint) is not None:
return True
Expand All @@ -100,9 +158,7 @@ def to_dialect(import_args: dict) -> Dialects | None:
return Dialects.TSQL
if dialect.upper() in Dialects.__members__:
return Dialects[dialect.upper()]
if dialect == "sqlserver":
return Dialects.TSQL
return None
return "None"


def to_physical_type_key(dialect: Dialects | str | None) -> str:
Expand Down Expand Up @@ -154,10 +210,23 @@ def to_col_type_normalized(column):

def get_description(column: sqlglot.expressions.ColumnDef) -> str | None:
if column.comments is None:
return None
description = column.find(sqlglot.expressions.CommentColumnConstraint)
if description:
return description.this.this
else:
return None
return " ".join(comment.strip() for comment in column.comments)


def get_tags(column: sqlglot.expressions.ColumnDef) -> list[str] | None:
    """Extract column-level tags (e.g. Snowflake ``WITH TAG (...)``).

    Args:
        column: The parsed column definition.

    Returns:
        The tags rendered as strings, or None when the column has no tags.
        (Bug fix: the annotation previously claimed ``str | None`` although
        a list has always been returned.)
    """
    tags = column.find(sqlglot.expressions.Tags)
    if not tags:
        return None
    # NOTE(review): find() returns the first Property node, which is then
    # iterated — presumably yielding the individual tag expressions; confirm
    # with sqlglot that all tags of a multi-tag column are captured.
    tag_property = tags.find(sqlglot.expressions.Property)
    return [str(tag) for tag in tag_property]


def get_max_length(column: sqlglot.expressions.ColumnDef) -> int | None:
col_type = to_col_type_normalized(column)
if col_type is None:
Expand Down Expand Up @@ -222,20 +291,22 @@ def map_type_from_sql(sql_type: str) -> str | None:
return "string"
elif sql_type_normed.startswith("int"):
return "int"
elif sql_type_normed.startswith("bigint"):
return "long"
elif sql_type_normed.startswith("tinyint"):
return "int"
elif sql_type_normed.startswith("smallint"):
return "int"
elif sql_type_normed.startswith("float"):
elif sql_type_normed.startswith("bigint"):
return "long"
elif sql_type_normed.startswith("float") or sql_type_normed.startswith("double") or sql_type_normed == "real":
return "float"
elif sql_type_normed.startswith("double"):
return "double"
elif sql_type_normed.startswith("decimal"):
elif sql_type_normed.startswith("number"):
return "decimal"
elif sql_type_normed.startswith("numeric"):
return "decimal"
elif sql_type_normed.startswith("decimal"):
return "decimal"
elif sql_type_normed.startswith("money"):
return "decimal"
elif sql_type_normed.startswith("bool"):
return "boolean"
elif sql_type_normed.startswith("bit"):
Expand All @@ -254,6 +325,7 @@ def map_type_from_sql(sql_type: str) -> str | None:
sql_type_normed == "timestamptz"
or sql_type_normed == "timestamp_tz"
or sql_type_normed == "timestamp with time zone"
or sql_type_normed == "timestamp_ltz"
):
return "timestamp_tz"
elif sql_type_normed == "timestampntz" or sql_type_normed == "timestamp_ntz":
Expand All @@ -273,7 +345,15 @@ def map_type_from_sql(sql_type: str) -> str | None:
elif sql_type_normed == "xml": # tsql
return "string"
else:
return "variant"
return "object"


def remove_variable_tokens(sql_script: str) -> str:
    """Unwrap templating placeholders so sqlglot can parse the script.

    Supported placeholder styles:
    - ``$(name)``  : sqlcmd scripting variables (T-SQL)
      https://learn.microsoft.com/en-us/sql/tools/sqlcmd/sqlcmd-use-scripting-variables?view=sql-server-ver17#b-use-the-setvar-command-interactively
    - ``${name}``  : Liquibase property substitution
      https://docs.liquibase.com/concepts/changelogs/property-substitution.html
    - ``{{name}}`` : Jinja (dbt) expressions
      https://docs.getdbt.com/guides/using-jinja?step=1

    Each placeholder is replaced by the bare variable name so the result
    remains a parseable identifier, e.g. "${db}.PUBLIC.t" -> "db.PUBLIC.t".
    """
    # Bug fix: the replacement must reference all three alternation groups
    # (\1\2\3 — only the group of the matching branch is non-empty).
    # Referencing just \1 erased ${...} and {{...}} placeholders to an empty
    # string instead of unwrapping them, leaving invalid SQL such as
    # ".PUBLIC.my_table". \s* tolerates Jinja-style inner padding: {{ name }}.
    pattern = r"\$\((\w+)\)|\$\{\s*(\w+)\s*\}|\{\{\s*(\w+)\s*\}\}"
    return re.sub(pattern, r"\1\2\3", sql_script)


def read_file(path):
Expand All @@ -287,4 +367,5 @@ def read_file(path):
)
with open(path, "r") as file:
file_content = file.read()
return file_content

return remove_variable_tokens(file_content)
2 changes: 1 addition & 1 deletion tests/fixtures/dbml/import/datacontract.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ models:
description: The business timestamp in UTC when the order was successfully
registered in the source system and the payment was successful.
order_total:
type: variant
type: object
required: true
primaryKey: false
unique: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ models:
description: The business timestamp in UTC when the order was successfully
registered in the source system and the payment was successful.
order_total:
type: variant
type: object
required: true
primaryKey: false
unique: false
Expand Down
42 changes: 42 additions & 0 deletions tests/fixtures/snowflake/import/ddl.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
CREATE TABLE IF NOT EXISTS ${database_name}.PUBLIC.my_table (
-- https://docs.snowflake.com/en/sql-reference/intro-summary-data-types
field_primary_key NUMBER(38,0) NOT NULL autoincrement start 1 increment 1 COMMENT 'Primary key',
field_not_null INT NOT NULL COMMENT 'Not null',
field_char CHAR(10) COMMENT 'Fixed-length string',
field_character CHARACTER(10) COMMENT 'Fixed-length string',
field_varchar VARCHAR(100) WITH TAG (SNOWFLAKE.CORE.PRIVACY_CATEGORY='IDENTIFIER', SNOWFLAKE.CORE.SEMANTIC_CATEGORY='NAME') COMMENT 'Variable-length string',

field_text TEXT COMMENT 'Large variable-length string',
field_string STRING COMMENT 'Large variable-length Unicode string',

field_tinyint TINYINT COMMENT 'Integer (0-255)',
field_smallint SMALLINT COMMENT 'Integer (-32,768 to 32,767)',
field_int INT COMMENT 'Integer (-2.1B to 2.1B)',
field_integer INTEGER COMMENT 'Integer full name(-2.1B to 2.1B)',
field_bigint BIGINT COMMENT 'Large integer (-9 quintillion to 9 quintillion)',

field_decimal DECIMAL(10, 2) COMMENT 'Fixed precision decimal',
field_numeric NUMERIC(10, 2) COMMENT 'Same as DECIMAL',

field_float FLOAT COMMENT 'Approximate floating-point',
field_float4 FLOAT4 COMMENT 'Approximate floating-point 4',
field_float8 FLOAT8 COMMENT 'Approximate floating-point 8',
field_real REAL COMMENT 'Smaller floating-point',

field_boulean BOOLEAN COMMENT 'Boolean-like (0 or 1)',

field_date DATE COMMENT 'Date only (YYYY-MM-DD)',
field_time TIME COMMENT 'Time only (HH:MM:SS)',
field_timestamp TIMESTAMP COMMENT 'More precise datetime',
field_timestamp_ltz TIMESTAMP_LTZ COMMENT 'More precise datetime with local time zone; time zone, if provided, isn`t stored.',
field_timestamp_ntz TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP() COMMENT 'More precise datetime with no time zone; time zone, if provided, isn`t stored.',
field_timestamp_tz TIMESTAMP_TZ COMMENT 'More precise datetime with time zone.',

field_binary BINARY(16) COMMENT 'Fixed-length binary',
field_varbinary VARBINARY(100) COMMENT 'Variable-length binary',

field_variant VARIANT COMMENT 'VARIANT data',
field_json OBJECT COMMENT 'JSON (Stored as text)',
UNIQUE(field_not_null),
PRIMARY KEY (field_primary_key)
) COMMENT = 'My Comment'
2 changes: 2 additions & 0 deletions tests/test_import_sql_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def test_cli():
"sql",
"--source",
sql_file_path,
"--dialect",
"postgres"
],
)
assert result.exit_code == 0
Expand Down
Loading
Loading