|
| 1 | +import asyncio |
| 2 | +import importlib |
| 3 | +import os |
| 4 | +from typing import Callable, Optional, Union |
| 5 | + |
| 6 | +import click |
| 7 | +import sqlalchemy |
| 8 | +from click import BadParameter, Context, Option |
| 9 | +from sqlalchemy import Column, Engine, Table |
| 10 | +from sqlalchemy.exc import ArgumentError |
| 11 | + |
| 12 | +from dbally.embeddings.litellm import LiteLLMEmbeddingClient |
| 13 | +from dbally.llms.base import LLM |
| 14 | +from dbally.llms.litellm import LiteLLM |
| 15 | +from dbally.similarity.faiss_store import FaissStore |
| 16 | +from dbally.similarity.index import SimilarityIndex |
| 17 | +from dbally.similarity.sqlalchemy_base import SimpleSqlAlchemyFetcher |
| 18 | +from dbally_codegen.autodiscovery import configure_text2sql_auto_discovery |
| 19 | +from dbally_codegen.generator import Text2SQLViewGenerator |
| 20 | + |
| 21 | + |
def faiss_builder(engine: sqlalchemy.Engine, table: sqlalchemy.Table, column: sqlalchemy.Column) -> SimilarityIndex:
    """
    Create a similarity index for one column, backed by a Faiss store.

    Args:
        engine: The SQLAlchemy engine passed to the fetcher.
        table: The table passed to the fetcher; its name is part of the index name.
        column: The column passed to the fetcher; its name is part of the index name.

    Returns:
        A similarity index combining a SimpleSqlAlchemyFetcher with a FaissStore.
    """
    column_fetcher = SimpleSqlAlchemyFetcher(
        sqlalchemy_engine=engine,
        column=column,
        table=table,
    )
    # The store uses "." as its index directory and a name derived from the table and column.
    faiss_store = FaissStore(
        index_dir=".",
        index_name=f"{table.name}_{column.name}_index",
        embedding_client=LiteLLMEmbeddingClient(),
    )
    return SimilarityIndex(fetcher=column_fetcher, store=faiss_store)
| 46 | + |
| 47 | + |
def validate_file_path(_ctx: Context, _param: Option, value: str) -> str:
    """
    Validate the output file path and make sure it ends with '.py'.

    Args:
        _ctx: The click context (unused).
        _param: The click option being validated (unused).
        value: The value of the option.

    Returns:
        The file path, with '.py' appended when no extension was given.

    Raises:
        BadParameter: If the path has no file name or an extension other than '.py'.
    """
    root, ext = os.path.splitext(value)
    # An empty value, or one ending in a path separator, has no file name;
    # appending '.py' to it would silently produce a hidden file called '.py'.
    if not os.path.basename(root):
        raise BadParameter("file path must include a file name.")
    if ext and ext != ".py":
        raise BadParameter("file extension must be '.py'.")
    return f"{root}.py"
| 64 | + |
| 65 | + |
def validate_db_url(_ctx: Context, _param: Option, value: Union[str, Engine]) -> Engine:
    """
    Turn the database option into a SQLAlchemy engine.

    Args:
        _ctx: The click context (unused).
        _param: The click option being validated (unused).
        value: The value of the option, either a connection string or an existing engine.

    Returns:
        The given engine, or a new engine created from the connection string.

    Raises:
        BadParameter: If the connection string is empty or create_engine raises ArgumentError.
    """
    if not isinstance(value, Engine):
        if not value:
            raise BadParameter("database connection string is required.")
        try:
            value = sqlalchemy.create_engine(value)
        except ArgumentError as exc:
            raise BadParameter("invalid database connection string.") from exc
    return value
| 85 | + |
| 86 | + |
def validate_llm_object(_ctx: Context, _param: Option, value: Union[str, LLM]) -> Optional[LLM]:
    """
    Validate the LLM object.

    Args:
        _ctx: The click context (unused).
        _param: The click option being validated (unused).
        value: The value of the option: an LLM instance, "None", "litellm:<model name>",
            or a path in the format 'module:object'.

    Returns:
        The validated LLM object, or None when no LLM was requested.

    Raises:
        BadParameter: If the LiteLLM model name is empty, or the loaded object is not an LLM.
    """
    if isinstance(value, LLM):
        return value
    if value is None or value == "None":
        return None

    litellm_prefix = "litellm:"
    if value.startswith(litellm_prefix):
        # Keep everything after the prefix: model names may contain ':' themselves
        # (version or tag suffixes), so splitting on every ':' would truncate them.
        model_name = value[len(litellm_prefix) :]
        if not model_name:
            raise BadParameter("The LiteLLM model name must not be empty.")
        return LiteLLM(model_name)

    llm = load_object(value)
    if not isinstance(llm, LLM):
        raise BadParameter("The LLM object must be an instance of the LLM class.")
    return llm
| 108 | + |
| 109 | + |
def validate_similarity_index_factory(
    _ctx: Context, _param: Option, value: Union[str, Callable[[Engine, Table, Column], SimilarityIndex]]
) -> Optional[Callable[[Engine, Table, Column], SimilarityIndex]]:
    """
    Resolve the similarity index factory option into a callable.

    Args:
        _ctx: The click context (unused).
        _param: The click option being validated (unused).
        value: The value of the option: a callable, "None", "faiss",
            or a path in the format 'module:object'.

    Returns:
        The factory callable, or None when no factory was requested.

    Raises:
        BadParameter: If the resolved object is not callable.
    """
    # A ready-made callable needs no resolution.
    if callable(value):
        return value
    if value is None or value == "None":
        return None
    if value == "faiss":
        return faiss_builder

    factory = None
    if value:
        factory = load_object(value)
    if callable(factory):
        return factory
    raise BadParameter("The similarity index factory must be a callable object.")
| 133 | + |
| 134 | + |
@click.command(help="Generate a Text2SQL view definition file.")
@click.option(
    "--file_path",
    default="text2sql_view.py",
    show_default=True,
    prompt="File path",
    help="The path to the file where the view will be generated.",
    callback=validate_file_path,
)
@click.option(
    "--db",
    default="sqlite://",
    show_default=True,
    prompt="Database URL",
    help="The database connection string.",
    callback=validate_db_url,
    type=click.UNPROCESSED,
)
@click.option(
    "--llm",
    default="None",
    show_default=True,
    prompt="LLM object",
    help="The path to the LLM object.",
    callback=validate_llm_object,
    type=click.UNPROCESSED,
)
@click.option(
    "--llm_description",
    is_flag=True,
    default=False,
    show_default=True,
    prompt="LLM table description?",
    help="Generate tables description using LLM.",
)
@click.option(
    "--similarity_index_factory",
    default="None",
    show_default=True,
    prompt="Similarity index factory",
    help="The path to the similarity index factory.",
    callback=validate_similarity_index_factory,
    type=click.UNPROCESSED,
)
def generate_text2sql_view(
    file_path: str,
    db: Engine,
    llm: Optional[LLM],
    llm_description: bool,
    similarity_index_factory: Optional[Callable[[Engine, sqlalchemy.Table, sqlalchemy.Column], SimilarityIndex]],
) -> None:
    """
    Discover the database tables and write a Text2SQL view definition file.

    Args:
        file_path: The path to the file where the view will be generated.
        db: The database engine produced by the --db option callback.
        llm: The LLM object produced by the --llm option callback, if any.
        llm_description: Whether to generate descriptions using the LLM.
        similarity_index_factory: The factory used to suggest similarity indexes, if any.
    """
    discovery = configure_text2sql_auto_discovery(db)
    if llm:
        discovery = discovery.use_llm(llm)
    # NOTE(review): this is applied even when no LLM was configured above —
    # confirm the builder supports that combination.
    if llm_description:
        discovery = discovery.generate_description_by_llm()
    if similarity_index_factory:
        discovery = discovery.suggest_similarity_indexes(similarity_index_factory)

    click.echo("Discovering tables...")
    discovered = asyncio.run(discovery.discover())
    click.echo(f"Discovered {len(discovered)} tables.")

    click.echo("Generating Text2SQL view...")
    source_code = Text2SQLViewGenerator(discovered).generate()

    # Create the parent directories first so the write below cannot fail on a missing folder.
    target_dir = os.path.dirname(file_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(file_path, "w", encoding="utf-8") as output:
        output.write(source_code)

    click.echo(f"Generated Text2SQL view in {file_path}.")
| 220 | + |
| 221 | + |
def load_object(path: str) -> object:
    """
    Load an object from a module.

    Args:
        path: The path to the object in the format 'module:object'.

    Returns:
        The object.

    Raises:
        BadParameter: If the path is malformed, or the module or the object is not found.
    """
    format_error = "The object must be in the format 'module:object'."
    try:
        module_name, object_name = path.split(":")
    except ValueError as exc:
        raise BadParameter(format_error) from exc

    # importlib raises ValueError for an empty module name and TypeError for a
    # relative one (no package is given); report both as a usage error instead.
    if not module_name or module_name.startswith(".") or not object_name:
        raise BadParameter(format_error)

    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        raise BadParameter(f"Could not find the module '{module_name}'.") from exc

    try:
        return getattr(module, object_name)
    except AttributeError as exc:
        raise BadParameter(f"Could not find the '{object_name}' object in the '{module_name}' module.") from exc
0 commit comments