Skip to content

Commit c92ce69

Browse files
authored
🏗️ Enterprise Flask Backend Refactoring - Production Ready for FastAPI Integration (#3)
* Lay the groundwork for refactoring the Flask backend * Refactor and modularize the Flask backend * Cleanup: Remove obsolete test helper files * Feat: Add new test infrastructure and fixtures * Refactor(db): Update database connection and SQL query logic * Refactor(api): Modularize company routes and wiki client adapter * Refactor(tests): Update existing tests to conform to new structure * Chore: Update gitignore file * Further update gitignore file * Chore: Update S&P 500 wiki information CSV * Refactor(db): Adjust database models and schema definitions * Refactor(queries): Update database query logic and SQL strings * Refactor(api): Adjust API routes and data handling in endpoints * Refactor(tests): Update testing infrastructure and dependencies * Chore: Add pytest_output.txt to gitignore * Feat(tests): Add new query tests and initial API directory structure * feat: Complete Flask backend refactoring with comprehensive test suite - Finalize pytest infrastructure with database mocking - Enable full test coverage for modular Blueprint architecture - Prepare foundation for FastAPI + Airflow parallel development - Support OpenBB Platform (Linux Foundation FINOS member) integration patterns - Resolve test suite blockers through comprehensive mock data configuration
1 parent 3a0ce72 commit c92ce69

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+889
-1012
lines changed

‎.gitignore‎

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,21 @@
1+
# Operating System / Editor specific
12
.DS_Store
3+
.vscode/
4+
__pycache__/
5+
*.pyc
26
.ipynb_checkpoints
7+
8+
# Environment variables
9+
*.env
10+
frontend/.env
311
backend/.env
412

13+
# Configuration files
514
./api/src/adapters/api_calls/api_client_config.py
15+
16+
#Virtual environments
17+
.venv/
18+
venv/
19+
frontend/.venv/
20+
backend/.venv/
21+
pytest_output.txt
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
,Ticker,Security,SEC filings,GICS Sector,GICS Sub-Industry,Employees
2+
0,MMM,3M Company,reports,Industrials,Industrial Conglomerates,95000
3+
1,MSFT,Microsoft Corp,reports,Information Technology,Software,181000
4+
2,ZTS,Zoetis,reports,Health Care,Pharmaceuticals,12000

‎backend/api/data/sp500/raw_data/sp500_stocks_wiki_info.csv‎

Lines changed: 4 additions & 506 deletions
Large diffs are not rendered by default.

‎backend/api/src/__init__.py‎

Lines changed: 12 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
from api.src.models.queries.sql_query_strings import companies_within_sub_sector_str, sub_sector_names_in_sector_query_str
1515
from settings import DB_HOST, DB_NAME, DB_PASSWORD, DB_USER, DEBUG, TESTING
1616

17-
def create_app(database='investment_analysis', testing=False, debug=True):
17+
def create_app(db_name='investment_analysis', db_user='postgres', db_password='postgres', testing=False):
1818
"""Create and configure an instance of the Flask application."""
1919
app = Flask(__name__)
2020

2121
# connect to the local computer's Postgres
2222
app.config.from_mapping(
23-
DB_USER = 'postgres',
24-
DB_NAME = database,
25-
DB_PASSWORD = 'postgres',
23+
DB_USER = db_user,
24+
DB_NAME = db_name,
25+
DB_PASSWORD = db_password,
2626
DB_HOST = '127.0.0.1',
2727
DEBUG = DEBUG,
2828
TESTING = TESTING
@@ -42,93 +42,13 @@ def create_app(database='investment_analysis', testing=False, debug=True):
4242

4343
@app.route('/')
4444
def root_url():
45-
return 'Welcome to the Economic Analysis api, through the prism of the S&P 500 stocks performance.'
46-
47-
@app.route('/sectors/')
48-
def sector_avg_financial_performance():
49-
"""
50-
url parameter format: f'/sectors/?financial_indicator={financial_indicator_name}'
51-
returns the quarterly average, over the most recent 8 quarters, of the financial indicator of each and every sector
52-
"""
53-
conn, cursor, financial_indicator = financial_performance_query_tools()
54-
historical_financials_json_dicts = get_historical_financials_json(financial_indicator, cursor)
55-
return json.dumps(historical_financials_json_dicts, default = str)
56-
57-
def get_historical_financials_json(financial_indicator, cursor):
58-
if financial_indicator in ['revenue', 'net_income', 'earnings_per_share', 'profit_margin']:
59-
historical_financials_json_dicts = (models.SubIndustry.
60-
find_avg_quarterly_financials_by_sector(financial_indicator, cursor))
61-
elif financial_indicator in ['closing_price', 'price_earnings_ratio']:
62-
historical_financials_json_dicts = (models.SubIndustry.
63-
find_sector_avg_price_pe(financial_indicator, cursor))
64-
# needs to handle dropdown menu selection of 'Done. Continue to the sub-Sector level.'
65-
else:
66-
historical_financials_json_dicts = 'Please enter the name of a financial_indicator, such as revenue, net_income.'
67-
return historical_financials_json_dicts
68-
69-
@app.route('/sectors/search')
70-
def sub_industries_within_sector():
71-
"""
72-
url parameter format example: /sectors/search?sector_name=Energy&financial_indicator=revenue
73-
returns the quarterly average, over the most recent 8 quarters, of the selected financial indicator and sector
74-
"""
75-
conn, cursor, sector_name, financial_indicator = sub_sector_performance_query_tools()
76-
if sector_name == 'all_sectors':
77-
conn = db.get_db()
78-
cursor = conn.cursor()
79-
sector_names = MixinSectorPricePE.get_all_sector_names(models.SubIndustry, cursor)
80-
return {'all_sector_names': sector_names}
81-
else:
82-
if financial_indicator in ['revenue', 'net_income', 'earnings_per_share', 'profit_margin']:
83-
historical_financials_json_dicts = (models.SubIndustry.
84-
find_avg_quarterly_financials_by_sub_industry(sector_name, financial_indicator, cursor))
85-
elif financial_indicator in ['closing_price', 'price_earnings_ratio']:
86-
historical_financials_json_dicts = (models.SubIndustry.
87-
find_sub_industry_avg_quarterly_price_pe(sector_name, financial_indicator, cursor))
88-
else:
89-
historical_financials_json_dicts = {'Please enter the name of a financial indicator.'}
90-
return json.dumps(historical_financials_json_dicts, default = str)
91-
92-
@app.route('/sub_sectors/search')
93-
def search_sub_sectors():
94-
conn, cursor, sub_sector_name, financial_indicator = company_performance_query_tools()
95-
if sub_sector_name == 'all_sub_sectors':
96-
sector_name = financial_indicator
97-
sub_sector_names = MixinSubSectorPricePE.get_sub_sector_names_of_sector(models.SubIndustry, sector_name, cursor)
98-
return json.dumps({'sub_sector_names': sub_sector_names}, default=str)
99-
else:
100-
if financial_indicator in ['revenue', 'net_income', 'earnings_per_share', 'profit_margin']:
101-
historical_financials_json_dicts = (models.Company.
102-
find_companies_quarterly_financials(sub_sector_name, financial_indicator, cursor))
103-
elif financial_indicator in ['closing_price', 'price_earnings_ratio']:
104-
historical_financials_json_dicts = (models.Company.find_company_quarterly_price_pe(sub_sector_name, financial_indicator, cursor))
105-
else:
106-
historical_financials_json_dicts = {'Please enter the name of a financial indicator.'}
107-
return json.dumps(historical_financials_json_dicts, default = str)
108-
109-
@app.route('/sub_sectors/<sub_industry_name>')
110-
def company_financial_performance(sub_industry_name):
111-
conn, cursor, financial_indicator = financial_performance_query_tools()
112-
if sub_industry_name == 'all_sub_industries':
113-
sector_name = financial_indicator
114-
sub_industry_names = MixinCompanyFinancialsPricePE.get_all_sub_sector_names_in_sector(models.Company, sector_name, cursor)
115-
return json.dumps({'sub_industry_names': sub_industry_names}, default=str)
116-
else:
117-
if financial_indicator in ['revenue', 'net_income', 'earnings_per_share', 'profit_margin']:
118-
historical_financials_json_dicts = (models.SubIndustry.
119-
find_companies_quarterly_financials(sub_sector_name, financial_indicator, cursor))
120-
elif financial_indicator in ['closing_price', 'price_earnings_ratio']:
121-
historical_financials_json_dicts = (models.SubIndustry.
122-
find_company_quarterly_price_pe(sector_name, financial_indicator, cursor))
123-
else:
124-
historical_financials_json_dicts = {'Please enter the name of a financial indicator.'}
125-
return json.dumps(historical_financials_json_dicts, default = str)
126-
127-
@app.route('/sectors/<sector_name>')
128-
def get_sub_sector_names_within_sector(sector_name):
129-
conn = db.get_db()
130-
cursor = conn.cursor()
131-
sub_sector_names = MixinSubSectorPricePE.get_sub_sector_names_of_sector(models.SubIndustry, sector_name, cursor)
132-
return json.dumps({'sub_sector_names': sub_sector_names}, default=str)
45+
return {'message': 'API is running.'}
46+
47+
from .routes.sector_routes import sector_bp
48+
from .routes.sub_sector_routes import sub_sector_bp
49+
from .routes.company_routes import company_bp
50+
app.register_blueprint(sector_bp)
51+
app.register_blueprint(sub_sector_bp)
52+
app.register_blueprint(company_bp)
13353

13454
return app

‎backend/api/src/adapters/wiki_page_client.py‎

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
# coding: utf-8
33

44

5+
import os
56
import pandas as pd
7+
import requests
68

79
def ingest_sp500_stocks_info():
810
"""
@@ -18,26 +20,34 @@ def ingest_sp500_stocks_info():
1820
"""
1921

2022
sp500_wiki_data_filepath = "./api/data/sp500/raw_data/sp500_stocks_wiki_info.csv"
21-
with open(sp500_wiki_data_filepath) as existing_file:
22-
if not existing_file:
23-
sp500_df = get_sp500_wiki_info()
24-
employees_total_df = get_employees_total()
25-
sp500_incl_employees_df = merge_df(sp500_df, employees_total_df)
26-
sp500_wiki_data_filepath = save_csv(sp500_incl_employees_df, sp500_wiki_data_filepath)
23+
import os
24+
if not os.path.exists(sp500_wiki_data_filepath):
25+
sp500_df = get_sp500_wiki_info()
26+
employees_total_df = get_employees_total()
27+
sp500_incl_employees_df = merge_df(sp500_df, employees_total_df)
28+
sp500_wiki_data_filepath = save_csv(sp500_incl_employees_df, sp500_wiki_data_filepath)
2729
return sp500_wiki_data_filepath
2830

31+
import requests
32+
2933
def get_sp500_wiki_info():
3034
"""ingest each and every S&P 500 company's basic info from the Wikipedia web page"""
31-
sp500_df = pd.read_html('https://en.wikipedia.org/wiki/List_of_S%26P_500_companies')[0]
35+
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
36+
url = 'https://en.wikipedia.org/wiki/List_of_S%26P_500_companies'
37+
response = requests.get(url, headers=headers)
38+
sp500_df = pd.read_html(response.text)[0]
3239
column_names = list(sp500_df.columns)
3340
column_names[0] = 'Ticker'
3441
sp500_df.columns = column_names
3542
return sp500_df
3643

3744
def get_employees_total():
3845
""" ingest each company's total number of employees """
39-
returned_dataframes = pd.read_html('https://www.liberatedstocktrader.com/sp-500-companies-list-by-number-of-employees/')
40-
employees_total = returned_dataframes[2]
46+
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'}
47+
url = 'https://www.liberatedstocktrader.com/sp-500-companies-list-by-number-of-employees/'
48+
response = requests.get(url, headers=headers)
49+
returned_dataframes = pd.read_html(response.text)
50+
employees_total = returned_dataframes[0]
4151
employees_total_df = employees_total.iloc[1:, 1:].copy()
4252
employees_total_df.columns = employees_total.iloc[0, 1:]
4353
return employees_total_df
@@ -51,10 +61,11 @@ def merge_df(sp500_df, employees_total_df):
5161
sp500_incl_employees_df = sp500_incl_employees_df[security_col_notna]
5262
return sp500_incl_employees_df
5363

54-
def save_csv(sp500_incl_employees_df, sp500_wiki_data_filepath):
55-
# save the merged dataframe in a csv file
56-
sp500_incl_employees_df.to_csv(sp500_wiki_data_filepath)
57-
return sp500_wiki_data_filepath
64+
def save_csv(sp500_incl_employees_df, filepath):
65+
"""Save the dataframe to a CSV file."""
66+
os.makedirs(os.path.dirname(filepath), exist_ok=True)
67+
sp500_incl_employees_df.to_csv(filepath)
68+
return filepath
5869

5970
if __name__ == "__main__":
6071
ingest_sp500_stocks_info()

‎backend/api/src/db/db.py‎

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1-
from flask import current_app
2-
from flask import g
1+
from flask import current_app, g
32
import psycopg2
43
from datetime import datetime, timedelta
5-
from settings import DB_USER, DB_NAME, DB_HOST, DB_PASSWORD, DEBUG, TESTING # backend/settings.py
6-
7-
# Connecting to Postgres on local Mac (parameters hard-coded):
8-
conn = psycopg2.connect(database = 'investment_analysis', user = 'postgres', password = 'postgres')
9-
cursor = conn.cursor()
10-
11-
def get_db():
12-
if "db" not in g:
13-
# connect to postgres on the local computer
14-
g.db = psycopg2.connect(user = 'postgres', password = 'postgres',
15-
dbname = current_app.config['DB_NAME']) # apply this to user, password in __init__.py (at the top of this script, already imported from SETTINGS)
16-
17-
"""
18-
# connect to postgres on the AWS RDS instance
19-
g.db = psycopg2.connect(user = 'postgres', password = 'postgres',
20-
dbname = current_app.config['DATABASE'])
21-
"""
4+
# from settings import DB_USER, DB_NAME, DB_HOST, DB_PASSWORD, DEBUG, TESTING # backend/settings.py
5+
6+
def get_db(db_name=None, db_user=None, db_password=None):
7+
if 'db' not in g:
8+
if db_name is None:
9+
db_name = current_app.config.get('DB_NAME', 'investment_analysis')
10+
if db_user is None:
11+
db_user = current_app.config.get('DB_USER', 'postgres')
12+
if db_password is None:
13+
db_password = current_app.config.get('DB_PASSWORD', 'postgres')
14+
g.db = psycopg2.connect(host='localhost', database=db_name, user=db_user, password=db_password)
2215
return g.db
2316

2417
"""
@@ -41,8 +34,7 @@ def close_db(e=None):
4134
def build_from_record(Class, record):
4235
if not record: return None
4336
attr = dict(zip(Class.columns, record))
44-
obj = Class()
45-
obj.__dict__ = attr
37+
obj = Class(**attr)
4638
return obj
4739

4840
def build_from_records(Class, records):

‎backend/api/src/models/company.py‎

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -35,32 +35,30 @@ def find_by_company_id(self, company_id, cursor):
3535
record = cursor.fetchone()
3636
return db.build_from_record(models.Company, record)
3737

38-
@classmethod
39-
def to_company_financials_history_json(self, sub_industry_name, cursor):
40-
# return in json format the financials and stock price, price-earnings-ratios of all the companies in a sub_industry
41-
company_names = MixinCompanyPricePE.get_all_company_names_in_sub_sector(sub_industry_name, cursor)
42-
companies_quarterly_financials_dict = {}
43-
for company_name in company_names:
44-
companies_quarterly_financials_dict[company_name] = to_quarterly_financials_json(self, company_name, cursor)
45-
return companies_quarterly_financials_dict
46-
4738
@classmethod
4839
def to_quarterly_financials_json(self, company_name, cursor):
49-
quarterly_financials_json = self.__dict__
50-
quarterly_reports_obj = self.get_company_quarterly_financials(self, company_name, cursor)
51-
quarterly_financials_json['Quarterly_financials'] = [report_obj.__dict__ for report_obj in quarterly_reports_obj]
52-
prices_pe_obj = self.get_company_quarterly_prices_pe(self, company_name, cursor)
53-
quarterly_financials_json['Closing_prices_and_P/E_ratio'] = [
54-
price_pe_obj.__dict__ for price_pe_obj in prices_pe_obj]
55-
return quarterly_financials_json
40+
quarterly_reports = self.get_company_quarterly_financials(company_name, cursor)
41+
prices_pe = self.get_company_quarterly_prices_pe(company_name, cursor)
42+
43+
# Create a dictionary for prices_pe for easier lookup
44+
prices_pe_dict = {(p.year, p.quarter): p for p in prices_pe}
45+
46+
merged_data = []
47+
for report in quarterly_reports:
48+
price_pe_data = prices_pe_dict.get((report.year, report.quarter))
49+
if price_pe_data:
50+
merged_record = {**report.__dict__, **price_pe_data.__dict__}
51+
merged_data.append(merged_record)
52+
53+
return merged_data
5654

5755
@classmethod
5856
def get_company_quarterly_financials(self, company_name, cursor):
5957
sql_str = f"""
6058
SELECT quarterly_reports.*
6159
FROM quarterly_reports JOIN {self.__table__}
6260
ON quarterly_reports.company_id = {self.__table__}.id
63-
WHERE {self.__table__}.company_name = %s;
61+
WHERE {self.__table__}.name = %s;
6462
"""
6563
cursor.execute(sql_str, (company_name,))
6664
records = cursor.fetchall()
@@ -76,7 +74,7 @@ def get_company_quarterly_prices_pe(self, company_name, cursor):
7674
"""
7775
cursor.execute(sql_str, (company_name,))
7876
records = cursor.fetchall()
79-
return db.build_from_records(models.QuarterlyReport, records)
77+
return db.build_from_records(models.PricePE, records)
8078

8179
@classmethod
8280
def find_companies_quarterly_financials(self, sub_sector_name:str, financial_indicator:str, cursor):
@@ -89,31 +87,29 @@ def find_companies_quarterly_financials(self, sub_sector_name:str, financial_ind
8987
financial_indicator name, year, quarter], and their corresponding values stored in a list as
9088
the dictionary value.
9189
"""
92-
companies_quarterly_financials_json = self.to_company_quarterly_financials_json(sub_sector_name, financial_indicator, cursor)
90+
companies_quarterly_financials_json = self.to_all_companies_quarterly_financials_json(sub_sector_name, financial_indicator, cursor)
9391
single_financial_indicator_json = extract_single_financial_indicator(financial_indicator, companies_quarterly_financials_json)
9492
return single_financial_indicator_json
9593

9694
@classmethod
97-
def to_company_quarterly_financials_json(self, sub_sector_name, financial_indicator, cursor):
98-
company_names = MixinCompanyPricePE.get_all_company_names_in_sub_sector(self, sub_sector_name, cursor)
95+
def to_all_companies_quarterly_financials_json(self, sub_sector_name, financial_indicator, cursor):
96+
company_names = self.get_all_company_names_in_sub_sector(sub_sector_name, cursor)
9997
avg_quarterly_financials_dict = {}
10098
for company_name in company_names:
101-
avg_quarterly_financials_dict[company_name] = (MixinCompanyFinancials.
102-
to_quarterly_financials_json(self, company_name, cursor))
99+
avg_quarterly_financials_dict[company_name] = self.to_quarterly_financials_json(company_name, cursor)
103100
return avg_quarterly_financials_dict
104101

105102
@classmethod
106103
def find_company_quarterly_price_pe(self, sub_sector_name:str, financial_indicator:str, cursor):
107-
companies_quarterly_price_pe_json = self.to_company_quarterly_price_pe_json(sub_sector_name, financial_indicator, cursor)
104+
companies_quarterly_price_pe_json = self.to_all_companies_quarterly_price_pe_json(sub_sector_name, financial_indicator, cursor)
108105
single_financial_indicator_json = extract_single_financial_indicator(financial_indicator, companies_quarterly_price_pe_json)
109106
return single_financial_indicator_json
110107

111108
@classmethod
112-
def to_company_quarterly_price_pe_json(self, sub_sector_name, financial_indicator, cursor):
113-
company_names = MixinCompanyPricePE.get_all_company_names_in_sub_sector(self, sub_sector_name, cursor)
109+
def to_all_companies_quarterly_price_pe_json(self, sub_sector_name, financial_indicator, cursor):
110+
company_names = self.get_all_company_names_in_sub_sector(sub_sector_name, cursor)
114111
avg_quarterly_price_pe_dict = {}
115112
for company_name in company_names:
116-
avg_quarterly_price_pe_dict[company_name] = (MixinCompanyPricePE.
117-
to_quarterly_price_pe_json(self, company_name, cursor))
113+
avg_quarterly_price_pe_dict[company_name] = self.to_quarterly_price_pe_json(company_name, cursor)
118114
return avg_quarterly_price_pe_dict
119115

‎backend/api/src/models/price_pe.py‎

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,22 @@
33

44
class PricePE:
55
__table__ = 'prices_pe'
6-
columns = ['id', 'date', 'company_id', 'closing_price', 'price_earnings_ratio']
6+
columns = ['id', 'date', 'company_id', 'closing_price', 'price_earnings_ratio', 'year', 'quarter']
77

88
def __init__(self, **kwargs):
99
for key in kwargs.keys():
1010
if key not in self.columns:
1111
raise f"{key} not in {self.columns}"
1212
for k, v in kwargs.items():
1313
setattr(self, k, v)
14+
if hasattr(self, 'date'):
15+
self.year = self.date.year
16+
self.quarter = (self.date.month - 1) // 3 + 1
1417

1518
@classmethod
1619
def find_by_company_id(self, company_id, cursor):
1720
sql_str = f"""SELECT * FROM {self.__table__}
1821
WHERE company_id = %s;"""
1922
cursor.execute(sql_str, (company_id,))
2023
records = cursor.fetchall()
21-
return db.build_from_records(models.SubIndustry, records)
24+
return db.build_from_records(self, records)

0 commit comments

Comments
 (0)