1212from sqlalchemy import create_engine , event
1313from sqlalchemy .orm import scoped_session , sessionmaker
1414
15- import aim .storage .drop_table_cascade # noqa: F401
15+ if os .environ .get ("AIM_USE_PG" , False ):
16+ import aim .storage .drop_table_cascade # noqa: F401
1617
1718class ObjectCache :
1819 def __init__ (self , data_fetch_func , key_func ):
@@ -47,8 +48,6 @@ def __getitem__(self, key):
4748
4849
4950class DB (ObjectFactory ):
50- _DB_NAME = 'app'
51- _DEFAULT_PORT = 5432
5251 _pool = WeakValueDictionary ()
5352
5453 _caches = dict ()
@@ -57,17 +56,25 @@ class DB(ObjectFactory):
5756 def __init__ (self , path : str , readonly : bool = False ):
5857 import logging
5958
60- super ().__init__ ()
61- pg_dbname = os .environ ['AIM_PG_DBNAME_RUNS' ]
62- self .path = pg_dbname
63- self .db_url = self .get_db_url (self .path )
59+ super ().__init__ ()
60+ if os .environ .get ("AIM_USE_PG" , False ):
61+ self .path = os .environ ['AIM_PG_DBNAME_RUNS' ]
62+ engine_options = {
63+ "pool_pre_ping" : True ,
64+ }
65+ else :
66+ self .path = path
67+ engine_options = {
68+ "pool_size" : 10 ,
69+ "max_overflow" : 20 ,
70+ }
71+
72+ self .db_url = self .get_db_url (self .path )
6473 self .readonly = readonly
6574 self .engine = create_engine (
6675 self .db_url ,
6776 echo = (logging .INFO >= int (os .environ .get (AIM_LOG_LEVEL_KEY , logging .WARNING ))),
68- pool_pre_ping = True
69- # pool_size=10,
70- # max_overflow=20,
77+ ** engine_options ,
7178 )
7279 event .listen (self .engine , 'connect' , lambda c , _ : c .execute ('pragma foreign_keys=on' ))
7380 self .session_cls = scoped_session (sessionmaker (autoflush = False , bind = self .engine ))
@@ -82,18 +89,26 @@ def from_path(cls, path: str, readonly: bool = False):
8289 return db
8390
8491 @staticmethod
85- def get_default_url ():
86- pg_dbname = os .environ ['AIM_PG_DBNAME_RUNS' ]
87- return DB .get_db_url (pg_dbname )
92+ def get_default_url ():
93+ return DB .get_db_url (".aim" )
8894
8995 @staticmethod
9096 def get_db_url (path : str ) -> str :
91- pg_user = os .environ ['AIM_PG_USER' ]
92- pg_password = os .environ ['AIM_PG_PASSWORD' ]
93- pg_host = os .environ ['AIM_PG_HOST' ]
94- pg_port = os .environ ['AIM_PG_PORT' ]
97+ if os .environ .get ("AIM_USE_PG" , False ):
98+ pg_dbname = os .environ ['AIM_PG_DBNAME_RUNS' ]
99+ pg_user = os .environ ['AIM_PG_USER' ]
100+ pg_password = os .environ ['AIM_PG_PASSWORD' ]
101+ pg_host = os .environ ['AIM_PG_HOST' ]
102+ pg_port = os .environ ['AIM_PG_PORT' ]
103+ db_url = f"postgresql://{ pg_user } :{ pg_password } @{ pg_host } :{ pg_port } /{ pg_dbname } "
104+ else :
105+ db_dialect = "sqlite"
106+ db_name = "run_metadata.sqlite"
107+ if os .path .exists (path ):
108+ db_url = f'{ db_dialect } :///{ path } /{ db_name } '
109+ else :
110+ raise RuntimeError (f'Cannot find database { path } . Please init first.' )
95111
96- db_url = f"postgresql://{ pg_user } :{ pg_password } @{ pg_host } :{ pg_port } /{ path } "
97112 return db_url
98113
99114 @property
0 commit comments