Skip to content

Commit 6dfeda0

Browse files
author
Adrià Garriga-Alonso
authored
Merge pull request #3 from AlignmentResearch/feature/support-sqlite-and-postgres
Add branch for sqllite and postgres
2 parents ceec06c + 9cb33c5 commit 6dfeda0

File tree

4 files changed

+57
-31
lines changed

4 files changed

+57
-31
lines changed

aim/storage/migrations/env.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from alembic.config import Config
1010
from sqlalchemy import create_engine
1111

12-
import aim.storage.drop_table_cascade # noqa: F401
12+
if os.environ.get("AIM_USE_PG", False):
13+
import aim.storage.drop_table_cascade # noqa: F401
1314

1415

1516
# this is the Alembic Config object, which provides

aim/storage/structured/db.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from sqlalchemy import create_engine, event
1313
from sqlalchemy.orm import scoped_session, sessionmaker
1414

15-
import aim.storage.drop_table_cascade # noqa: F401
15+
if os.environ.get("AIM_USE_PG", False):
16+
import aim.storage.drop_table_cascade # noqa: F401
1617

1718
class ObjectCache:
1819
def __init__(self, data_fetch_func, key_func):
@@ -47,8 +48,6 @@ def __getitem__(self, key):
4748

4849

4950
class DB(ObjectFactory):
50-
_DB_NAME = 'app'
51-
_DEFAULT_PORT = 5432
5251
_pool = WeakValueDictionary()
5352

5453
_caches = dict()
@@ -57,17 +56,25 @@ class DB(ObjectFactory):
5756
def __init__(self, path: str, readonly: bool = False):
5857
import logging
5958

60-
super().__init__()
61-
pg_dbname = os.environ['AIM_PG_DBNAME_RUNS']
62-
self.path = pg_dbname
63-
self.db_url = self.get_db_url(self.path)
59+
super().__init__()
60+
if os.environ.get("AIM_USE_PG", False):
61+
self.path = os.environ['AIM_PG_DBNAME_RUNS']
62+
engine_options = {
63+
"pool_pre_ping": True,
64+
}
65+
else:
66+
self.path = path
67+
engine_options = {
68+
"pool_size": 10,
69+
"max_overflow": 20,
70+
}
71+
72+
self.db_url = self.get_db_url(self.path)
6473
self.readonly = readonly
6574
self.engine = create_engine(
6675
self.db_url,
6776
echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))),
68-
pool_pre_ping=True
69-
# pool_size=10,
70-
# max_overflow=20,
77+
**engine_options,
7178
)
7279
event.listen(self.engine, 'connect', lambda c, _: c.execute('pragma foreign_keys=on'))
7380
self.session_cls = scoped_session(sessionmaker(autoflush=False, bind=self.engine))
@@ -82,18 +89,26 @@ def from_path(cls, path: str, readonly: bool = False):
8289
return db
8390

8491
@staticmethod
85-
def get_default_url():
86-
pg_dbname = os.environ['AIM_PG_DBNAME_RUNS']
87-
return DB.get_db_url(pg_dbname)
92+
def get_default_url():
93+
return DB.get_db_url(".aim")
8894

8995
@staticmethod
9096
def get_db_url(path: str) -> str:
91-
pg_user = os.environ['AIM_PG_USER']
92-
pg_password = os.environ['AIM_PG_PASSWORD']
93-
pg_host = os.environ['AIM_PG_HOST']
94-
pg_port = os.environ['AIM_PG_PORT']
97+
if os.environ.get("AIM_USE_PG", False):
98+
pg_dbname = os.environ['AIM_PG_DBNAME_RUNS']
99+
pg_user = os.environ['AIM_PG_USER']
100+
pg_password = os.environ['AIM_PG_PASSWORD']
101+
pg_host = os.environ['AIM_PG_HOST']
102+
pg_port = os.environ['AIM_PG_PORT']
103+
db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}"
104+
else:
105+
db_dialect = "sqlite"
106+
db_name = "run_metadata.sqlite"
107+
if os.path.exists(path):
108+
db_url = f'{db_dialect}:///{path}/{db_name}'
109+
else:
110+
raise RuntimeError(f'Cannot find database {path}. Please init first.')
95111

96-
db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{path}"
97112
return db_url
98113

99114
@property

aim/web/api/db.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,21 @@
99
from sqlalchemy.ext.declarative import declarative_base
1010
from sqlalchemy.orm import sessionmaker
1111

12-
import aim.storage.drop_table_cascade # noqa: F401
12+
if os.environ.get("AIM_USE_PG", False):
13+
import aim.storage.drop_table_cascade # noqa: F401
14+
engine_options = {}
15+
else:
16+
engine_options = {
17+
"connect_args": {"check_same_thread": False},
18+
"pool_size": 10,
19+
"max_overflow": 20,
20+
}
21+
1322

1423
engine = create_engine(
1524
get_db_url(),
1625
echo=(logging.INFO >= int(os.environ.get(AIM_LOG_LEVEL_KEY, logging.WARNING))),
17-
# connect_args={'check_same_thread': False},
18-
# pool_size=10,
19-
# max_overflow=20,
26+
**engine_options,
2027
)
2128

2229
SessionLocal = sessionmaker(autoflush=False, bind=engine)

aim/web/utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,14 @@ def get_root_path():
4040

4141

4242
def get_db_url():
43-
pg_user = os.environ['AIM_PG_USER']
44-
pg_password = os.environ['AIM_PG_PASSWORD']
45-
pg_host = os.environ['AIM_PG_HOST']
46-
pg_port = os.environ['AIM_PG_PORT']
47-
pg_dbname = os.environ['AIM_PG_DBNAME_WEB']
48-
49-
db_url = f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}"
50-
return db_url
43+
if os.environ.get("AIM_USE_PG", False):
44+
pg_user = os.environ['AIM_PG_USER']
45+
pg_password = os.environ['AIM_PG_PASSWORD']
46+
pg_host = os.environ['AIM_PG_HOST']
47+
pg_port = os.environ['AIM_PG_PORT']
48+
pg_dbname = os.environ['AIM_PG_DBNAME_WEB']
49+
50+
return f"postgresql://{pg_user}:{pg_password}@{pg_host}:{pg_port}/{pg_dbname}"
51+
else:
52+
return 'sqlite:///{}/{}/aim_db'.format(get_root_path(), get_aim_repo_name())
53+

0 commit comments

Comments
 (0)