Skip to content

Commit 4284344

Browse files
authored
feat(codegen): add text2sql views code generation (#40)
1 parent e0305ff commit 4284344

File tree

5 files changed

+897
-6
lines changed

5 files changed

+897
-6
lines changed

src/dbally_cli/main.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,13 @@
11
import click
22

3+
from dbally_cli.text2sql import generate_text2sql_view
4+
35

46
@click.group()
def cli() -> None:
    """
    Command line tool for interacting with dbally.
    """
    # The docstring above is kept verbatim: click displays it as the group's help text.


# Register the Text2SQL view generator as a subcommand of the root group.
cli.add_command(generate_text2sql_view)

src/dbally_cli/text2sql.py

Lines changed: 248 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,248 @@
1+
import asyncio
2+
import importlib
3+
import os
4+
from typing import Callable, Optional, Union
5+
6+
import click
7+
import sqlalchemy
8+
from click import BadParameter, Context, Option
9+
from sqlalchemy import Column, Engine, Table
10+
from sqlalchemy.exc import ArgumentError
11+
12+
from dbally.embeddings.litellm import LiteLLMEmbeddingClient
13+
from dbally.llms.base import LLM
14+
from dbally.llms.litellm import LiteLLM
15+
from dbally.similarity.faiss_store import FaissStore
16+
from dbally.similarity.index import SimilarityIndex
17+
from dbally.similarity.sqlalchemy_base import SimpleSqlAlchemyFetcher
18+
from dbally_codegen.autodiscovery import configure_text2sql_auto_discovery
19+
from dbally_codegen.generator import Text2SQLViewGenerator
20+
21+
22+
def faiss_builder(engine: sqlalchemy.Engine, table: sqlalchemy.Table, column: sqlalchemy.Column) -> SimilarityIndex:
    """
    Create a similarity index for a single column, backed by a Faiss store.

    The store is configured with ``index_dir="."`` and an index name derived
    from the table and column names (``<table>_<column>_index``). Embeddings
    are produced by a default-constructed ``LiteLLMEmbeddingClient``.

    Args:
        engine: The SQLAlchemy engine handed to the fetcher.
        table: The table that owns the column.
        column: The column to build the index for.

    Returns:
        The similarity index combining the fetcher and the Faiss store.
    """
    # Built in the same order as the nested original: fetcher first, then the store.
    values_fetcher = SimpleSqlAlchemyFetcher(
        sqlalchemy_engine=engine,
        column=column,
        table=table,
    )
    faiss_store = FaissStore(
        index_dir=".",
        index_name=f"{table.name}_{column.name}_index",
        embedding_client=LiteLLMEmbeddingClient(),
    )
    return SimilarityIndex(fetcher=values_fetcher, store=faiss_store)
46+
47+
48+
def validate_file_path(_ctx: Context, _param: Option, value: str) -> str:
    """
    Validate the output file path and normalise its extension to '.py'.

    Args:
        _ctx: The click context (unused, required by the callback signature).
        _param: The click option being validated (unused).
        value: The value of the option.

    Returns:
        The validated file path, always ending with '.py'.

    Raises:
        BadParameter: If the file name is empty or the extension is not '.py'.
    """
    root, ext = os.path.splitext(value)
    # An empty value or a bare directory such as "out/" has no file name part;
    # appending ".py" to it would silently create a hidden file named ".py".
    if not os.path.basename(root):
        raise BadParameter("file name must not be empty.")
    if ext not in ("", ".py"):
        raise BadParameter("file extension must be '.py'.")
    # A missing extension defaults to ".py"; an existing one is already ".py".
    return f"{root}.py"
64+
65+
66+
def validate_db_url(_ctx: Context, _param: Option, value: Union[str, Engine]) -> Engine:
    """
    Validate the database connection string and turn it into an engine.

    Args:
        _ctx: The click context (unused, required by the callback signature).
        _param: The click option being validated (unused).
        value: The value of the option: a connection string, or an engine
            that has already been created.

    Returns:
        The SQLAlchemy engine for the given connection string.

    Raises:
        BadParameter: If the connection string is missing or invalid, or the
            database driver it refers to cannot be imported.
    """
    # An already-built engine is passed through untouched.
    if isinstance(value, Engine):
        return value
    if not value:
        raise BadParameter("database connection string is required.")

    try:
        return sqlalchemy.create_engine(value)
    except ArgumentError as exc:
        raise BadParameter("invalid database connection string.") from exc
    except ImportError as exc:
        # create_engine imports the DBAPI driver named in the URL; report a
        # missing driver as a usage error instead of a raw traceback.
        raise BadParameter(f"could not import the database driver: {exc}") from exc
85+
86+
87+
def validate_llm_object(_ctx: Context, _param: Option, value: Union[str, LLM]) -> Optional[LLM]:
    """
    Validate the LLM object.

    Accepted values are an existing LLM instance, the string "None" (or None)
    for no LLM, "litellm:<model name>" for a LiteLLM model, or a
    'module:object' path to an LLM instance.

    Args:
        _ctx: The click context (unused, required by the callback signature).
        _param: The click option being validated (unused).
        value: The value of the option.

    Returns:
        The validated LLM object, or None if no LLM was requested.

    Raises:
        BadParameter: If the path cannot be loaded or the loaded object is not
            an instance of the LLM class.
    """
    if isinstance(value, LLM):
        return value
    if value is None or value == "None":
        return None
    if value.startswith("litellm:"):
        # Split on the first colon only: model names may contain colons
        # themselves (e.g. "litellm:ollama/llama3:8b") and must not be truncated.
        return LiteLLM(value.split(":", 1)[1])

    llm = load_object(value)
    if not isinstance(llm, LLM):
        raise BadParameter("The LLM object must be an instance of the LLM class.")
    return llm
108+
109+
110+
def validate_similarity_index_factory(
    _ctx: Context, _param: Option, value: Union[str, Callable[[Engine, Table, Column], SimilarityIndex]]
) -> Optional[Callable[[Engine, Table, Column], SimilarityIndex]]:
    """
    Resolve the similarity index factory option into a callable.

    Args:
        _ctx: The click context (unused, required by the callback signature).
        _param: The click option being validated (unused).
        value: The value of the option: a callable, "None", "faiss", or a
            'module:object' path.

    Returns:
        The similarity index factory, or None if no factory was requested.

    Raises:
        BadParameter: If the value does not resolve to a callable object.
    """
    # Callables are accepted as they are.
    if callable(value):
        return value
    if value is None or value == "None":
        return None
    # "faiss" is a shortcut for the built-in Faiss factory.
    if value == "faiss":
        return faiss_builder

    factory = None
    if value:
        factory = load_object(value)
    if callable(factory):
        return factory
    raise BadParameter("The similarity index factory must be a callable object.")
133+
134+
135+
@click.command(help="Generate a Text2SQL view definition file.")
@click.option(
    "--file_path",
    default="text2sql_view.py",
    show_default=True,
    prompt="File path",
    help="The path to the file where the view will be generated.",
    callback=validate_file_path,
)
@click.option(
    "--db",
    default="sqlite://",
    show_default=True,
    prompt="Database URL",
    help="The database connection string.",
    callback=validate_db_url,
    type=click.UNPROCESSED,
)
@click.option(
    "--llm",
    default="None",
    show_default=True,
    prompt="LLM object",
    help="The path to the LLM object.",
    callback=validate_llm_object,
    type=click.UNPROCESSED,
)
@click.option(
    "--llm_description",
    is_flag=True,
    default=False,
    show_default=True,
    prompt="LLM table description?",
    help="Generate tables description using LLM.",
)
@click.option(
    "--similarity_index_factory",
    default="None",
    show_default=True,
    prompt="Similarity index factory",
    help="The path to the similarity index factory.",
    callback=validate_similarity_index_factory,
    type=click.UNPROCESSED,
)
def generate_text2sql_view(
    file_path: str,
    db: Engine,
    llm: Optional[LLM],
    llm_description: bool,
    similarity_index_factory: Optional[Callable[[Engine, sqlalchemy.Table, sqlalchemy.Column], SimilarityIndex]],
) -> None:
    """
    Discover the database tables and write a Text2SQL view definition file.

    The arguments arrive already converted by the option callbacks.

    Args:
        file_path: The path of the file the generated view is written to.
        db: The engine created from the database connection string.
        llm: The LLM to use during discovery, or None.
        llm_description: Whether table descriptions are generated by the LLM.
        similarity_index_factory: The factory used to suggest similarity
            indexes, or None.
    """
    # Configure the auto-discovery step by step, enabling only what was requested.
    discovery = configure_text2sql_auto_discovery(db)
    if llm:
        discovery = discovery.use_llm(llm)
    if llm_description:
        discovery = discovery.generate_description_by_llm()
    if similarity_index_factory:
        discovery = discovery.suggest_similarity_indexes(similarity_index_factory)

    click.echo("Discovering tables...")
    # discover() is a coroutine; run it to completion from this synchronous command.
    discovered_tables = asyncio.run(discovery.discover())
    click.echo(f"Discovered {len(discovered_tables)} tables.")

    click.echo("Generating Text2SQL view...")
    source_code = Text2SQLViewGenerator(discovered_tables).generate()

    # Create the target directory first when the path has a directory part.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(file_path, "w", encoding="utf-8") as output:
        output.write(source_code)

    click.echo(f"Generated Text2SQL view in {file_path}.")
220+
221+
222+
def load_object(path: str) -> object:
    """
    Load an object from a module.

    Args:
        path: The path to the object in the format 'module:object'. The object
            part may be a dotted attribute path, e.g. 'package.module:Class.attr'.

    Returns:
        The object.

    Raises:
        BadParameter: If the path is malformed, or the module or the object
            is not found.
    """
    try:
        module_name, object_name = path.split(":")
    except ValueError as exc:
        raise BadParameter("The object must be in the format 'module:object'.") from exc

    # An empty module name would make importlib raise a bare ValueError;
    # report it as a usage error instead.
    if not module_name or not object_name:
        raise BadParameter("The object must be in the format 'module:object'.")

    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        raise BadParameter(f"Could not find the module '{module_name}'.") from exc

    # Resolve the attribute path one segment at a time so that nested objects
    # ('Class.attr') can be addressed as well as top-level ones.
    target = module
    try:
        for attribute in object_name.split("."):
            target = getattr(target, attribute)
    except AttributeError as exc:
        raise BadParameter(f"Could not find the '{object_name}' object in the '{module_name}' module.") from exc
    return target

src/dbally_codegen/autodiscovery.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88
from dbally.llms.base import LLM
99
from dbally.prompts import PromptTemplate
1010
from dbally.similarity.index import SimilarityIndex
11-
from dbally.similarity.store import SimilarityStore
1211
from dbally.views.freeform.text2sql import ColumnConfig, TableConfig
1312

1413
DISCOVERY_TEMPLATE = PromptTemplate(
@@ -141,7 +140,7 @@ async def select_index(
141140
column: Column,
142141
description: str,
143142
connection: Connection,
144-
) -> Optional[SimilarityStore]:
143+
) -> Optional[SimilarityIndex]:
145144
"""
146145
Select the similarity index for the column.
147146
@@ -167,7 +166,7 @@ async def select_index(
167166
column: Column,
168167
description: str,
169168
connection: Connection,
170-
) -> Optional[SimilarityStore]:
169+
) -> Optional[SimilarityIndex]:
171170
"""
172171
Select the similarity index for the column.
173172
@@ -201,7 +200,7 @@ async def select_index(
201200
column: Column,
202201
description: str,
203202
connection: Connection,
204-
) -> Optional[SimilarityStore]:
203+
) -> Optional[SimilarityIndex]:
205204
"""
206205
Select the similarity index for the column using LLM.
207206
@@ -333,14 +332,14 @@ async def discover(self) -> List[TableConfig]:
333332
)
334333
columns.append(
335334
ColumnConfig(
336-
name=column.name,
335+
name=str(column.name),
337336
data_type=str(column.type),
338337
similarity_index=similarity_index,
339338
)
340339
)
341340
tables.append(
342341
TableConfig(
343-
name=table.name,
342+
name=str(table.name),
344343
description=description,
345344
columns=columns,
346345
)

0 commit comments

Comments
 (0)