Skip to content

Commit 55d8c75

Browse files
test metadata mappings
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent b0b58fb commit 55d8c75

File tree

2 files changed

+186
-0
lines changed

2 files changed

+186
-0
lines changed
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""
2+
Tests for SEA metadata column mappings and normalization.
3+
"""
4+
5+
import pytest
6+
from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
7+
from databricks.sql.backend.sea.utils.result_column import ResultColumn
8+
from databricks.sql.backend.sea.utils.conversion import SqlType
9+
10+
11+
class TestMetadataColumnMappings:
12+
"""Test suite for metadata column mappings."""
13+
14+
def test_result_column_creation(self):
15+
"""Test ResultColumn data class creation and attributes."""
16+
col = ResultColumn("TABLE_CAT", "catalog", SqlType.STRING)
17+
assert col.thrift_col_name == "TABLE_CAT"
18+
assert col.sea_col_name == "catalog"
19+
assert col.thrift_col_type == SqlType.STRING
20+
21+
def test_result_column_with_none_sea_name(self):
22+
"""Test ResultColumn when SEA doesn't return this column."""
23+
col = ResultColumn("TYPE_CAT", None, SqlType.STRING)
24+
assert col.thrift_col_name == "TYPE_CAT"
25+
assert col.sea_col_name is None
26+
assert col.thrift_col_type == SqlType.STRING
27+
28+
def test_catalog_columns_mapping(self):
29+
"""Test catalog columns mapping for getCatalogs."""
30+
catalog_cols = MetadataColumnMappings.CATALOG_COLUMNS
31+
assert len(catalog_cols) == 1
32+
33+
catalog_col = catalog_cols[0]
34+
assert catalog_col.thrift_col_name == "TABLE_CAT"
35+
assert catalog_col.sea_col_name == "catalog"
36+
assert catalog_col.thrift_col_type == SqlType.STRING
37+
38+
def test_schema_columns_mapping(self):
39+
"""Test schema columns mapping for getSchemas."""
40+
schema_cols = MetadataColumnMappings.SCHEMA_COLUMNS
41+
assert len(schema_cols) == 2
42+
43+
# Check TABLE_SCHEM column
44+
schema_col = schema_cols[0]
45+
assert schema_col.thrift_col_name == "TABLE_SCHEM"
46+
assert schema_col.sea_col_name == "databaseName"
47+
assert schema_col.thrift_col_type == SqlType.STRING
48+
49+
# Check TABLE_CATALOG column
50+
catalog_col = schema_cols[1]
51+
assert catalog_col.thrift_col_name == "TABLE_CATALOG"
52+
assert catalog_col.sea_col_name is None
53+
assert catalog_col.thrift_col_type == SqlType.STRING
54+
55+
def test_table_columns_mapping(self):
56+
"""Test table columns mapping for getTables."""
57+
table_cols = MetadataColumnMappings.TABLE_COLUMNS
58+
assert len(table_cols) == 10
59+
60+
# Test key columns
61+
expected_mappings = [
62+
("TABLE_CAT", "catalogName", SqlType.STRING),
63+
("TABLE_SCHEM", "namespace", SqlType.STRING),
64+
("TABLE_NAME", "tableName", SqlType.STRING),
65+
("TABLE_TYPE", "tableType", SqlType.STRING),
66+
("REMARKS", "remarks", SqlType.STRING),
67+
("TYPE_CAT", None, SqlType.STRING),
68+
("TYPE_SCHEM", None, SqlType.STRING),
69+
("TYPE_NAME", None, SqlType.STRING),
70+
("SELF_REFERENCING_COL_NAME", None, SqlType.STRING),
71+
("REF_GENERATION", None, SqlType.STRING),
72+
]
73+
74+
for i, (thrift_name, sea_name, sql_type) in enumerate(expected_mappings):
75+
col = table_cols[i]
76+
assert col.thrift_col_name == thrift_name
77+
assert col.sea_col_name == sea_name
78+
assert col.thrift_col_type == sql_type
79+
80+
def test_column_columns_mapping(self):
81+
"""Test column columns mapping for getColumns."""
82+
column_cols = MetadataColumnMappings.COLUMN_COLUMNS
83+
# Should have 23 columns (not including IS_GENERATED_COLUMN)
84+
assert len(column_cols) == 23
85+
86+
# Test some key columns
87+
key_columns = {
88+
"TABLE_CAT": ("catalogName", SqlType.STRING),
89+
"TABLE_SCHEM": ("namespace", SqlType.STRING),
90+
"TABLE_NAME": ("tableName", SqlType.STRING),
91+
"COLUMN_NAME": ("col_name", SqlType.STRING),
92+
"DATA_TYPE": (None, SqlType.INT),
93+
"TYPE_NAME": ("columnType", SqlType.STRING),
94+
"COLUMN_SIZE": ("columnSize", SqlType.INT),
95+
"DECIMAL_DIGITS": ("decimalDigits", SqlType.INT),
96+
"NUM_PREC_RADIX": ("radix", SqlType.INT),
97+
"ORDINAL_POSITION": ("ordinalPosition", SqlType.INT),
98+
"IS_NULLABLE": ("isNullable", SqlType.STRING),
99+
"IS_AUTOINCREMENT": ("isAutoIncrement", SqlType.STRING),
100+
}
101+
102+
for col in column_cols:
103+
if col.thrift_col_name in key_columns:
104+
expected_sea_name, expected_type = key_columns[col.thrift_col_name]
105+
assert col.sea_col_name == expected_sea_name
106+
assert col.thrift_col_type == expected_type
107+
108+
def test_is_generated_column_not_included(self):
109+
"""Test that IS_GENERATED_COLUMN is not included in COLUMN_COLUMNS."""
110+
column_names = [
111+
col.thrift_col_name for col in MetadataColumnMappings.COLUMN_COLUMNS
112+
]
113+
assert "IS_GENERATEDCOLUMN" not in column_names
114+
115+
def test_column_type_consistency(self):
116+
"""Test that column types are consistent with JDBC spec."""
117+
# Test numeric types
118+
assert MetadataColumnMappings.DATA_TYPE_COLUMN.thrift_col_type == SqlType.INT
119+
assert MetadataColumnMappings.COLUMN_SIZE_COLUMN.thrift_col_type == SqlType.INT
120+
assert (
121+
MetadataColumnMappings.BUFFER_LENGTH_COLUMN.thrift_col_type
122+
== SqlType.TINYINT
123+
)
124+
assert (
125+
MetadataColumnMappings.DECIMAL_DIGITS_COLUMN.thrift_col_type == SqlType.INT
126+
)
127+
assert (
128+
MetadataColumnMappings.NUM_PREC_RADIX_COLUMN.thrift_col_type == SqlType.INT
129+
)
130+
assert (
131+
MetadataColumnMappings.ORDINAL_POSITION_COLUMN.thrift_col_type
132+
== SqlType.INT
133+
)
134+
assert MetadataColumnMappings.NULLABLE_COLUMN.thrift_col_type == SqlType.INT
135+
assert (
136+
MetadataColumnMappings.SQL_DATA_TYPE_COLUMN.thrift_col_type == SqlType.INT
137+
)
138+
assert (
139+
MetadataColumnMappings.SQL_DATETIME_SUB_COLUMN.thrift_col_type
140+
== SqlType.INT
141+
)
142+
assert (
143+
MetadataColumnMappings.CHAR_OCTET_LENGTH_COLUMN.thrift_col_type
144+
== SqlType.INT
145+
)
146+
assert (
147+
MetadataColumnMappings.SOURCE_DATA_TYPE_COLUMN.thrift_col_type
148+
== SqlType.SMALLINT
149+
)
150+
151+
# Test string types
152+
assert MetadataColumnMappings.CATALOG_COLUMN.thrift_col_type == SqlType.STRING
153+
assert MetadataColumnMappings.SCHEMA_COLUMN.thrift_col_type == SqlType.STRING
154+
assert (
155+
MetadataColumnMappings.TABLE_NAME_COLUMN.thrift_col_type == SqlType.STRING
156+
)
157+
assert (
158+
MetadataColumnMappings.IS_NULLABLE_COLUMN.thrift_col_type == SqlType.STRING
159+
)
160+
assert (
161+
MetadataColumnMappings.IS_AUTO_INCREMENT_COLUMN.thrift_col_type
162+
== SqlType.STRING
163+
)

tests/unit/test_sea_backend.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from databricks.sql.backend.sea.models.base import ServiceError, StatementStatus
1616
from databricks.sql.backend.sea.result_set import SeaResultSet
17+
from databricks.sql.backend.sea.utils.metadata_mappings import MetadataColumnMappings
1718
from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType
1819
from databricks.sql.parameters.native import IntegerParameter, TDbsqlParameter
1920
from databricks.sql.thrift_api.TCLIService import ttypes
@@ -756,6 +757,11 @@ def test_get_catalogs(self, sea_client, sea_session_id, mock_cursor):
756757
# Verify the result is correct
757758
assert result == mock_result_set
758759

760+
# Verify prepare_metadata_columns was called
761+
mock_result_set.prepare_metadata_columns.assert_called_once_with(
762+
MetadataColumnMappings.CATALOG_COLUMNS
763+
)
764+
759765
def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
760766
"""Test the get_schemas method with various parameter combinations."""
761767
# Mock the execute_command method
@@ -818,6 +824,12 @@ def test_get_schemas(self, sea_client, sea_session_id, mock_cursor):
818824
)
819825
assert "Catalog name is required for get_schemas" in str(excinfo.value)
820826

827+
# Verify prepare_metadata_columns was called for successful cases
828+
assert mock_result_set.prepare_metadata_columns.call_count == 2
829+
mock_result_set.prepare_metadata_columns.assert_called_with(
830+
MetadataColumnMappings.SCHEMA_COLUMNS
831+
)
832+
821833
def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
822834
"""Test the get_tables method with various parameter combinations."""
823835
# Mock the execute_command method
@@ -905,6 +917,11 @@ def test_get_tables(self, sea_client, sea_session_id, mock_cursor):
905917
enforce_embedded_schema_correctness=False,
906918
)
907919

920+
# Verify prepare_metadata_columns was called
921+
mock_result_set.prepare_metadata_columns.assert_called_with(
922+
MetadataColumnMappings.TABLE_COLUMNS
923+
)
924+
908925
def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
909926
"""Test the get_columns method with various parameter combinations."""
910927
# Mock the execute_command method
@@ -969,6 +986,12 @@ def test_get_columns(self, sea_client, sea_session_id, mock_cursor):
969986
)
970987
assert "Catalog name is required for get_columns" in str(excinfo.value)
971988

989+
# Verify prepare_metadata_columns was called for successful cases
990+
assert mock_result_set.prepare_metadata_columns.call_count == 2
991+
mock_result_set.prepare_metadata_columns.assert_called_with(
992+
MetadataColumnMappings.COLUMN_COLUMNS
993+
)
994+
972995
def test_get_tables_with_cloud_fetch(
973996
self, sea_client_cloud_fetch, sea_session_id, mock_cursor
974997
):

0 commit comments

Comments
 (0)