|
1 | 1 | """Utility functions for the text2sql API.""" |
2 | 2 | import json |
3 | | -from typing import Any, Dict, List |
| 3 | +from typing import Dict, List, Optional, TypedDict |
4 | 4 |
|
5 | 5 | from litellm import completion, batch_completion |
6 | 6 |
|
7 | 7 | from api.config import Config |
8 | 8 |
|
9 | 9 |
|
10 | | -def create_combined_description( |
11 | | - table_info: Dict[str, Dict[str, Any]], batch_size: int = 10 |
12 | | -) -> Dict[str, Dict[str, Any]]: |
| 10 | +class ForeignKeyInfo(TypedDict): |
| 11 | + """Foreign key constraint information (all four keys are required strings).""" |
| 12 | + constraint_name: str |
| 13 | + column: str  # NOTE(review): presumably the referencing column in the owning table; confirm |
| 14 | + referenced_table: str |
| 15 | + referenced_column: str |
| 16 | + |
| 17 | + |
| 18 | +class ColumnInfo(TypedDict): |
| 19 | + """Column metadata information (every key is required; only ``default`` may be None).""" |
| 20 | + type: str |
| 21 | + null: str  # NOTE(review): a string, not a bool; presumably the database's nullability flag, confirm |
| 22 | + key: str  # NOTE(review): presumably the database's key/index marker for the column; confirm |
| 23 | + description: str |
| 24 | + default: Optional[str] |
| 25 | + sample_values: List[str] |
| 26 | + |
| 27 | + |
| 28 | +class TableInfo(TypedDict): |
| 29 | + """Table metadata information; the value type of the mapping taken and returned by create_combined_description.""" |
| 30 | + description: str |
| 31 | + columns: Dict[str, ColumnInfo]  # NOTE(review): presumably keyed by column name; confirm |
| 32 | + foreign_keys: List[ForeignKeyInfo] |
| 33 | + col_descriptions: List[str] |
| 34 | + |
| 35 | + |
| 36 | +def create_combined_description( # pylint: disable=too-many-locals |
| 37 | + table_info: Dict[str, TableInfo], batch_size: int = 10 |
| 38 | +) -> Dict[str, TableInfo]: |
13 | 39 | """ |
14 | 40 | Create a combined description from a dictionary of table descriptions. |
15 | 41 |
|
16 | 42 | Args: |
17 | | - table_info (Dict[str, Dict[str, Any]]): Mapping of table names to their metadata. |
| 43 | + table_info (Dict[str, TableInfo]): Mapping of table names to their metadata. |
18 | 44 | batch_size (int): Number of tables to process per batch when calling the LLM (default: 10). |
19 | 45 | Returns: |
20 | | - Dict[str, Dict[str, Any]]: Updated mapping containing descriptions. |
| 46 | + Dict[str, TableInfo]: Updated mapping containing descriptions. |
21 | 47 | """ |
22 | 48 | if not isinstance(table_info, dict): |
23 | 49 | raise TypeError("table_info must be a dictionary keyed by table name.") |
|
0 commit comments