Skip to content

Commit fb7408f

Browse files
authored
Merge pull request #412 from hackforla/fix-build
Refactor data access layer
2 parents 8f5c5fc + 2bc5771 commit fb7408f

File tree

3 files changed

+23
-16
lines changed

3 files changed

+23
-16
lines changed

api/openapi_server/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from os import environ as env
55

66
from openapi_server import encoder
7+
from openapi_server.models.database import DataAccessLayer
78
from openapi_server.exceptions import AuthError, handle_auth_error
89
from dotenv import load_dotenv, find_dotenv
910

@@ -13,6 +14,7 @@
1314
load_dotenv(ENV_FILE)
1415
SECRET_KEY=env.get('SECRET_KEY')
1516

17+
DataAccessLayer.db_init()
1618

1719
def main():
1820
app = connexion.App(__name__, specification_dir='./_spec/')

api/openapi_server/controllers/service_provider_controller.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from openapi_server.models import database as db
1010
from sqlalchemy.orm import Session
1111

12-
dal = db.DataAccessLayer()
13-
dal.db_init()
12+
db_engine = db.DataAccessLayer.get_engine()
1413

1514
def create_service_provider(): # noqa: E501
1615
"""Create a housing program service provider
@@ -27,7 +26,7 @@ def create_service_provider(): # noqa: E501
2726
connexion.request.get_json()).to_dict()
2827
except ValueError:
2928
return traceback.format_exc(ValueError), 400
30-
with Session(dal.engine) as session:
29+
with Session(db_engine) as session:
3130
row = db.HousingProgramServiceProvider(
3231
provider_name=provider["provider_name"]
3332
)
@@ -51,7 +50,7 @@ def delete_service_provider(provider_id): # noqa: E501
5150
5251
:rtype: None
5352
"""
54-
with Session(dal.engine) as session:
53+
with Session(db_engine) as session:
5554
query = session.query(
5655
db.HousingProgramServiceProvider).filter(
5756
db.HousingProgramServiceProvider.id == provider_id)
@@ -71,7 +70,7 @@ def get_service_provider_by_id(provider_id): # noqa: E501
7170
7271
:rtype: ServiceProviderWithId
7372
"""
74-
with Session(dal.engine) as session:
73+
with Session(db_engine) as session:
7574
row = session.get(
7675
db.HousingProgramServiceProvider, provider_id)
7776
if row != None:
@@ -93,7 +92,7 @@ def get_service_providers(): # noqa: E501
9392
:rtype: List[ServiceProviderWithId]
9493
"""
9594
resp = []
96-
with Session(dal.engine) as session:
95+
with Session(db_engine) as session:
9796
table = session.query(db.HousingProgramServiceProvider).all()
9897
for row in table:
9998
provider = ServiceProvider(
@@ -122,7 +121,7 @@ def update_service_provider(provider_id): # noqa: E501
122121
connexion.request.get_json()).to_dict()
123122
except ValueError:
124123
return traceback.format_exc(ValueError), 400
125-
with Session(dal.engine) as session:
124+
with Session(db_engine) as session:
126125
query = session.query(
127126
db.HousingProgramServiceProvider).filter(
128127
db.HousingProgramServiceProvider.id == provider_id)

api/openapi_server/models/database.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,16 +233,22 @@ class ProgramCaseStatusLog(Base):
233233
src_status = Column(Integer, ForeignKey('case_status.id'), nullable=False)
234234
dest_status = Column(Integer, ForeignKey('case_status.id'), nullable=False)
235235

236-
237-
238236
class DataAccessLayer:
239-
connection = None
240-
engine = None
237+
_engine = None
241238

242239
# temporary local sqlite DB, replace with conn str for postgres container port for real e2e
243-
conn_string = "sqlite:///./homeuniteus.db"
240+
_conn_string = "sqlite:///./homeuniteus.db"
244241

245-
def db_init(self, conn_string=None):
246-
self.engine = create_engine(conn_string or self.conn_string, echo=True, future=True)
247-
Base.metadata.create_all(bind=self.engine)
248-
self.connection = self.engine.connect()
242+
@classmethod
243+
def db_init(cls, conn_string=None):
244+
Base.metadata.create_all(bind=cls.get_engine(conn_string))
245+
246+
@classmethod
247+
def connect(cls):
248+
return cls.get_engine().connect()
249+
250+
@classmethod
251+
def get_engine(cls, conn_string=None):
252+
if cls._engine == None:
253+
cls._engine = create_engine(conn_string or cls._conn_string, echo=True, future=True)
254+
return cls._engine

0 commit comments

Comments
 (0)