Skip to content

Commit 190f5d2

Browse files
committed
Add postgres support
1 parent 6688775 commit 190f5d2

File tree

7 files changed

+196
-1
lines changed

7 files changed

+196
-1
lines changed

deploy_ai_search/src/deploy_ai_search/text_2_sql_schema_store.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def excluded_fields_for_database_engine(self):
6565
engine_specific_fields = ["Database"]
6666
elif self.database_engine == DatabaseEngine.DATABRICKS:
6767
engine_specific_fields = ["Catalog"]
68+
elif self.database_engine == DatabaseEngine.POSTGRESQL:
69+
engine_specific_fields = ["Database"]
6870

6971
return [
7072
field

text_2_sql/autogen/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies = [
1111
"autogen-ext[azure,openai]==0.4.0.dev11",
1212
"grpcio>=1.68.1",
1313
"pyyaml>=6.0.2",
14-
"text_2_sql_core[snowflake,databricks]",
14+
"text_2_sql_core[snowflake,databricks,postgresql]",
1515
]
1616

1717
[dependency-groups]

text_2_sql/text_2_sql_core/pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ databricks = [
4646
"databricks-sql-connector>=3.0.1",
4747
"pyarrow>=14.0.2,<17",
4848
]
49+
postgresql = [
50+
"psycopg>=3.2.3",
51+
]
52+
4953

5054
[build-system]
5155
requires = ["hatchling"]
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from text_2_sql_core.connectors.sql import SqlConnector
4+
import psycopg
5+
from typing import Annotated
6+
import os
7+
import logging
8+
import json
9+
10+
from text_2_sql_core.utils.database import DatabaseEngine
11+
12+
13+
class PostgresqlSqlConnector(SqlConnector):
    """SQL connector that executes queries against PostgreSQL through psycopg's async API."""

    def __init__(self):
        super().__init__()

        self.database_engine = DatabaseEngine.POSTGRESQL

    async def query_execution(
        self,
        sql_query: Annotated[str, "The SQL query to run against the database."],
        cast_to: any = None,
        limit=None,
    ) -> list[dict]:
        """Run the SQL query against the PostgreSQL database asynchronously.

        A new connection is opened for every call using the connection string in
        the ``Text2Sql__DatabaseConnectionString`` environment variable.

        Args:
        ----
            sql_query (str): The SQL query to run against the database.
            cast_to: Optional class exposing ``from_sql_row(row, columns)``. When
                given, every row is converted with it instead of into a dict.
            limit (int | None): Maximum number of rows to fetch. ``None`` fetches
                every row.

        Returns:
        -------
            list[dict]: The results of the SQL query. Holds ``cast_to`` instances
            instead of dicts when ``cast_to`` is provided, and is empty when the
            statement produced no result set.

        Raises:
        ------
            KeyError: If ``Text2Sql__DatabaseConnectionString`` is not set.
        """
        # Lazy %-formatting: same message, but only rendered if INFO is enabled.
        logging.info("Running query: %s", sql_query)
        results = []
        connection_string = os.environ["Text2Sql__DatabaseConnectionString"]

        # Establish an asynchronous connection to the PostgreSQL database
        async with psycopg.AsyncConnection.connect(connection_string) as conn:
            # Create an asynchronous cursor
            async with conn.cursor() as cursor:
                await cursor.execute(sql_query)

                # psycopg leaves ``description`` as None when the statement did
                # not return a result set (e.g. DDL or plain INSERT/UPDATE).
                # Iterating it, or fetching, would raise, so only read rows
                # when there is something to read.
                description = cursor.description
                if description is not None:
                    # Fetch column names
                    columns = [column[0] for column in description]

                    # Fetch rows based on the limit
                    if limit is not None:
                        rows = await cursor.fetchmany(limit)
                    else:
                        rows = await cursor.fetchall()

                    # Process the rows
                    for row in rows:
                        if cast_to:
                            results.append(cast_to.from_sql_row(row, columns))
                        else:
                            results.append(dict(zip(columns, row)))

        logging.debug("Results: %s", results)
        return results

    async def get_entity_schemas(
        self,
        text: Annotated[
            str,
            "The text to run a semantic search against. Relevant entities will be returned.",
        ],
        excluded_entities: Annotated[
            list[str],
            "The entities to exclude from the search results. Pass the entity property of entities (e.g. 'SalesLT.Address') you already have the schemas for to avoid getting repeated entities.",
        ] = [],
        as_json: bool = True,
    ) -> str:
        """Gets the schema of a view or table in the SQL Database by selecting the most relevant entity based on the search term. Several entities may be returned.

        Args:
        ----
            text (str): The text to run the search against.
            excluded_entities (list[str]): Entities to leave out of the results.
                The list is forwarded unchanged to the AI Search connector and is
                not modified by this method.
            as_json (bool): When True (default) the schemas are serialised to a
                JSON string; when False the list of schema dicts is returned as is.

        Returns:
            str: The schema of the views or tables in JSON format. When
            ``as_json`` is False the raw list of schema dicts is returned instead.
        """

        schemas = await self.ai_search_connector.get_entity_schemas(
            text, excluded_entities
        )

        for schema in schemas:
            # Collapse schema + entity into the qualified name used in a FROM
            # clause, then drop the two source keys.
            schema["SelectFromEntity"] = ".".join([schema["Schema"], schema["Entity"]])

            del schema["Entity"]
            del schema["Schema"]

        if as_json:
            return json.dumps(schemas, default=str)
        else:
            return schemas

text_2_sql/text_2_sql_core/src/text_2_sql_core/data_dictionary/data_dictionary_creator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,8 @@ def excluded_fields_for_database_engine(self):
758758
engine_specific_fields = ["Database"]
759759
elif self.database_engine == DatabaseEngine.DATABRICKS:
760760
engine_specific_fields = ["Catalog"]
761+
elif self.database_engine == DatabaseEngine.POSTGRESQL:
762+
engine_specific_fields = ["Database"]
761763
else:
762764
engine_specific_fields = []
763765

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from text_2_sql_core.data_dictionary.data_dictionary_creator import (
    DataDictionaryCreator,
    EntityItem,
)
import os

from text_2_sql_core.utils.database import DatabaseEngine

# The connector class is named ``PostgresqlSqlConnector``; importing the
# previously used ``PostgresSqlConnector`` fails with ImportError.
from text_2_sql_core.connectors.postgresql_sql import PostgresqlSqlConnector


def _escape_sql_literal(value) -> str:
    """Return ``value`` as text safe to place between single quotes in SQL.

    Embedded single quotes are doubled, which is the standard SQL escape.
    """
    return str(value).replace("'", "''")


class PostgresqlDataDictionaryCreator(DataDictionaryCreator):
    """Data dictionary creator that supplies PostgreSQL catalog queries."""

    def __init__(self, **kwargs):
        """Initialize the creator for a PostgreSQL database.

        Args:
            **kwargs: Forwarded unchanged to ``DataDictionaryCreator.__init__``.

        Raises:
            KeyError: If the ``Text2Sql__DatabaseName`` environment variable is
                not set.
        """
        super().__init__(**kwargs)

        self.database = os.environ["Text2Sql__DatabaseName"]
        self.database_engine = DatabaseEngine.POSTGRESQL

        self.sql_connector = PostgresqlSqlConnector()

    @property
    def extract_table_entities_sql_query(self) -> str:
        """A property to extract table entities from a PostgreSQL database."""
        # col_description(oid, 0) looks up the comment stored against
        # sub-object 0 of the relation, i.e. the comment on the table itself.
        return """SELECT
            t.table_name AS entity,
            t.table_schema AS entity_schema,
            pg_catalog.col_description(c.oid, 0) AS definition
        FROM
            information_schema.tables t
        JOIN
            pg_catalog.pg_class c ON c.relname = t.table_name
            AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = t.table_schema)
        WHERE
            t.table_type = 'BASE TABLE'
        ORDER BY entity_schema, entity;"""

    @property
    def extract_view_entities_sql_query(self) -> str:
        """A property to extract view entities from a PostgreSQL database."""
        # information_schema.views exposes the view's name as ``table_name``;
        # it has no ``view_name`` column, so the query must use table_name.
        return """SELECT
            v.table_name AS entity,
            v.table_schema AS entity_schema,
            pg_catalog.col_description(c.oid, 0) AS definition
        FROM
            information_schema.views v
        JOIN
            pg_catalog.pg_class c ON c.relname = v.table_name
            AND c.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = v.table_schema)
        ORDER BY entity_schema, entity;"""

    def extract_columns_sql_query(self, entity: EntityItem) -> str:
        """Build the query that lists the columns of one table or view.

        Args:
            entity (EntityItem): The entity whose ``entity_schema`` and ``name``
                select the relation to describe.

        Returns:
            str: A SQL statement returning ``name``, ``data_type`` and
            ``definition`` (the column comment) ordered by column position.
        """
        # The identifiers are embedded as string literals, so escape any single
        # quote to keep the statement well-formed for unusual object names.
        schema_literal = _escape_sql_literal(entity.entity_schema)
        table_literal = _escape_sql_literal(entity.name)

        return f"""SELECT
            c.column_name AS name,
            c.data_type AS data_type,
            pg_catalog.col_description(t.oid, c.ordinal_position) AS definition
        FROM
            information_schema.columns c
        JOIN
            pg_catalog.pg_class t ON t.relname = c.table_name
            AND t.relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)
        WHERE
            c.table_schema = '{schema_literal}'
            AND c.table_name = '{table_literal}'
        ORDER BY c.ordinal_position;"""

    @property
    def extract_entity_relationships_sql_query(self) -> str:
        """A property to extract entity relationships from a PostgreSQL database."""
        # NOTE(review): this returns the referenced *constraint* schema/name,
        # not the referenced table and columns - confirm that this is the shape
        # the base class expects for relationships.
        return """SELECT
            tc.table_schema AS entity_schema,
            tc.table_name AS entity,
            rc.unique_constraint_schema AS foreign_entity_schema,
            rc.unique_constraint_name AS foreign_entity_constraint,
            rc.constraint_name AS foreign_key_constraint
        FROM
            information_schema.referential_constraints rc
        JOIN
            information_schema.table_constraints tc
            ON rc.constraint_schema = tc.constraint_schema
            AND rc.constraint_name = tc.constraint_name
        WHERE
            tc.constraint_type = 'FOREIGN KEY'
        ORDER BY
            entity_schema, entity, foreign_entity_schema, foreign_entity_constraint;"""

text_2_sql/text_2_sql_core/src/text_2_sql_core/utils/database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ class DatabaseEngine(StrEnum):
77
DATABRICKS = "DATABRICKS"
88
SNOWFLAKE = "SNOWFLAKE"
99
TSQL = "TSQL"
10+
POSTGRESQL = "POSTGRESQL"

0 commit comments

Comments
 (0)