44
55import json
66import os
7- import sqlite3
87import shutil
8+ import sqlite3
99
1010import numpy as np
1111import pandas as pd
1212import sqlalchemy .types as sqlalchemy_types
13+ import yaml
1314from astropy .coordinates import SkyCoord
1415from astropy .table import Table as AstropyTable
1516from astropy .units .quantity import Quantity
1617from sqlalchemy import Table , and_ , create_engine , event , or_ , text
1718from sqlalchemy .engine import Engine
1819from sqlalchemy .orm import declarative_base , sessionmaker
1920from sqlalchemy .orm .query import Query
21+ from sqlalchemy .schema import CreateSchema
2022from tqdm import tqdm
2123
2224from . import FOREIGN_KEY , PRIMARY_TABLE , PRIMARY_TABLE_KEY , REFERENCE_TABLES
@@ -166,7 +168,7 @@ def load_connection(connection_string, sqlite_foreign=True, base=None, connectio
166168 session = Session ()
167169
168170 # Enable foreign key checks in SQLite
169- if "sqlite" in connection_string and sqlite_foreign :
171+ if connection_string . startswith ( "sqlite" ) and sqlite_foreign :
170172 set_sqlite ()
171173 # elif 'postgresql' in connection_string:
172174 # # Set up schema in postgres (must be lower case?)
@@ -189,23 +191,60 @@ def set_sqlite_pragma(dbapi_connection, connection_record):
189191 cursor .close ()
190192
191193
192- def create_database (connection_string , drop_tables = False ):
194+ def create_database (connection_string , drop_tables = False , felis_schema = None ):
193195 """
194196 Create a database from a schema that utilizes the `astrodbkit2.astrodb.Base` class.
195197 Some databases, eg Postgres, must already exist but any tables should be dropped.
198+ The default behavior is to assume that a schema with SQLAlchemy definitions has been imported prior to calling this function.
199+ If instead, Felis is being used to define the schema, the path to the YAML file needs to be provided to the felis_schema parameter (as a string).
196200
197201 Parameters
198202 ----------
199203 connection_string : str
200204 Connection string to database
201205 drop_tables : bool
202206 Flag to drop existing tables. This is needed when the schema changes. (Default: False)
207+ felis_schema : str
208+ Path to schema yaml file
203209 """
204210
205- session , base , engine = load_connection (connection_string , base = Base )
206- if drop_tables :
207- base .metadata .drop_all ()
208- base .metadata .create_all (engine ) # this explicitly creates the database
211+ if felis_schema is not None :
212+ # Felis loader requires felis_schema
213+ from felis .datamodel import Schema
214+ from felis .metadata import MetaDataBuilder
215+
216+ # Load and validate the felis-formatted schema
217+ data = yaml .safe_load (open (felis_schema , "r" ))
218+ schema = Schema .model_validate (data )
219+ schema_name = data ["name" ] # get schema_name from the felis schema file
220+
221+ # engine = create_engine(connection_string)
222+ session , base , engine = load_connection (connection_string )
223+
224+ # Schema handling for various database types
225+ if connection_string .startswith ("sqlite" ):
226+ db_name = connection_string .split ("/" )[- 1 ]
227+ with engine .begin () as conn :
228+ conn .execute (text (f"ATTACH '{ db_name } ' AS { schema_name } " ))
229+ elif connection_string .startswith ("postgres" ):
230+ with engine .connect () as connection :
231+ connection .execute (CreateSchema (schema_name , if_not_exists = True ))
232+ connection .commit ()
233+
234+ # Drop tables, if requested
235+ if drop_tables :
236+ base .metadata .drop_all ()
237+
238+ # Create the database
239+ metadata = MetaDataBuilder (schema ).build ()
240+ metadata .create_all (bind = engine )
241+ base .metadata = metadata
242+ else :
243+ session , base , engine = load_connection (connection_string , base = Base )
244+ if drop_tables :
245+ base .metadata .drop_all ()
246+ base .metadata .create_all (engine ) # this explicitly creates the database
247+
209248 return session , base , engine
210249
211250
@@ -276,6 +315,7 @@ def __init__(
276315 column_type_overrides = {},
277316 sqlite_foreign = True ,
278317 connection_arguments = {},
318+ schema = None ,
279319 ):
280320 """
281321 Wrapper for database calls and utility functions
@@ -301,8 +341,15 @@ def __init__(
301341 Flag to enable/disable use of foreign keys with SQLite. Default: True
302342 connection_arguments : dict
303343 Additional connection arguments, like {'check_same_thread': False}. Default: {}
344+ schema : str
345+ Helper for setting default PostgreSQL schema. Equivalent to connection_arguments={"options": f"-csearch_path={schema}"}
304346 """
305347
348+ # Helper logic to set default postgres schema, if specified
349+ if connection_string .startswith ("postgres" ) and schema is not None :
350+ if connection_arguments .get ("options" ) is None :
351+ connection_arguments ["options" ] = f"-csearch_path={ schema } "
352+
306353 if connection_string == "sqlite://" :
307354 self .session , self .base , self .engine = create_database (connection_string )
308355 else :
0 commit comments