|
| 1 | +import pyodbc |
| 2 | +import struct |
| 3 | +import logging |
| 4 | +from azure.identity import DefaultAzureCredential |
| 5 | +import sys |
| 6 | +import argparse |
| 7 | + |
| 8 | +def main(): |
| 9 | + |
| 10 | + # Create an ArgumentParser object |
| 11 | + parser = argparse.ArgumentParser() |
| 12 | + |
| 13 | + parser.add_argument('--server', type=str, required=True, help='SQL Server name') |
| 14 | + parser.add_argument('--database', type=str, required=True, help='Database name') |
| 15 | + parser.add_argument('--chartJsFuncAppName', type=str, required=True, help='Chart JS function name') |
| 16 | + parser.add_argument('--ragFuncAppName', type=str, required=True, help='RAG function name') |
| 17 | + |
| 18 | + # Parse the arguments |
| 19 | + args = parser.parse_args() |
| 20 | + |
| 21 | + # Access the arguments & Set up the connection string |
| 22 | + server = f"{args.server}.database.windows.net" |
| 23 | + database = args.database |
| 24 | + driver = "{ODBC Driver 17 for SQL Server}" |
| 25 | + chartJsFuncAppName = args.chartJsFuncAppName |
| 26 | + ragFuncAppName = args.ragFuncAppName |
| 27 | + |
| 28 | + # Get the token using DefaultAzureCredential |
| 29 | + # Managed Identity for mid-AzureCloud |
| 30 | + credential = DefaultAzureCredential() |
| 31 | + |
| 32 | + token_bytes = credential.get_token( |
| 33 | + "https://database.windows.net/.default" |
| 34 | + ).token.encode("utf-16-LE") |
| 35 | + token_struct = struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes) |
| 36 | + SQL_COPT_SS_ACCESS_TOKEN = ( |
| 37 | + 1256 # This connection option is defined by microsoft in msodbcsql.h |
| 38 | + ) |
| 39 | + # Set up the connection |
| 40 | + connection_string = f"DRIVER={driver};SERVER={server};DATABASE={database};" |
| 41 | + conn = pyodbc.connect( |
| 42 | + connection_string, attrs_before={SQL_COPT_SS_ACCESS_TOKEN: token_struct} |
| 43 | + ) |
| 44 | + |
| 45 | + # Create a cursor and execute SQL commands |
| 46 | + cursor = conn.cursor() |
| 47 | + |
| 48 | + sql_query_chartjs_func = f""" |
| 49 | +IF NOT EXISTS (SELECT 1 FROM sys.database_principals WHERE name = '{chartJsFuncAppName}') |
| 50 | +BEGIN |
| 51 | + CREATE USER [{chartJsFuncAppName}] FROM EXTERNAL PROVIDER; |
| 52 | + ALTER ROLE db_datareader ADD MEMBER [{chartJsFuncAppName}]; -- Grant SELECT on all user tables and views. |
| 53 | + ALTER ROLE db_datawriter ADD MEMBER [{chartJsFuncAppName}]; -- Grant INSERT, UPDATE, and DELETE on all user tables and views. |
| 54 | +END |
| 55 | +""" |
| 56 | + |
| 57 | + sql_query_rag_func = f""" |
| 58 | +IF NOT EXISTS (SELECT 1 FROM sys.database_principals WHERE name = '{ragFuncAppName}') |
| 59 | +BEGIN |
| 60 | + CREATE USER [{ragFuncAppName}] FROM EXTERNAL PROVIDER; |
| 61 | + ALTER ROLE db_datareader ADD MEMBER [{ragFuncAppName}]; -- Grant SELECT on all user tables and views. |
| 62 | + ALTER ROLE db_datawriter ADD MEMBER [{ragFuncAppName}]; -- Grant INSERT, UPDATE, and DELETE on all user tables and views. |
| 63 | +END |
| 64 | +""" |
| 65 | + # Execute SQL commands to create the user and assign the db_datareader role |
| 66 | + cursor.execute(sql_query_chartjs_func) |
| 67 | + cursor.execute(sql_query_rag_func) |
| 68 | + |
| 69 | + conn.commit() |
| 70 | + |
| 71 | + # Close the connection |
| 72 | + cursor.close() |
| 73 | + conn.close() |
| 74 | + |
| 75 | + |
| 76 | +if __name__ == "__main__": |
| 77 | + logger = logging.getLogger("azure.identity") |
| 78 | + logger.setLevel(logging.INFO) |
| 79 | + handler = logging.StreamHandler(stream=sys.stdout) |
| 80 | + logger.addHandler(handler) |
| 81 | + |
| 82 | + main() |
0 commit comments