@@ -62,13 +62,21 @@ def _engine(self) -> Engine:
6262 if not self .valid ():
6363 raise DatabaseFileNotFoundError ("Database file not found." )
6464 database_url = f"sqlite:///{ Path (self .database_path )} "
65+ try :
66+ upgrade (self .database_path )
67+ except Exception as e :
68+ print (
69+ f"Failed to upgrade the database at { self .database_path } . "
70+ f"Reason: { e } "
71+ "Skipping upgrade. "
72+ )
73+ logger .error (
74+ f"Failed to upgrade the database at { self .database_path } . "
75+ f"Reason: { e } "
76+ "Skipping upgrade. "
77+ )
6578 engine = create_engine (database_url )
6679 inspector = inspect (engine )
67- if "analysis" not in inspector .get_table_names ():
68- upgrade (self .database_path )
69-
70- engine = create_engine (database_url )
71- inspector = inspect (engine )
7280 if "subject" not in inspector .get_table_names ():
7381 logger .info (
7482 "Running tickermood upgrade to create the subject table in the database."
@@ -116,20 +124,33 @@ def read_symbols(self) -> List[str]:
116124 data = pd .read_sql_query (query , self ._engine )
117125 return data ["symbol" ].tolist ()
118126
127+ def copy_sec_to_analysis (self ) -> None :
128+ with Session (self ._engine ) as session :
129+ query = text (
130+ """UPDATE analysis
131+ SET
132+ occurrences = (SELECT secshareincrease.occurrences
133+ FROM secshareincrease WHERE secshareincrease.ticker = analysis.symbol),
134+ total_value = (SELECT secshareincrease.total_value
135+ FROM secshareincrease WHERE secshareincrease.ticker = analysis.symbol),
136+ total_increase = (SELECT secshareincrease.total_increase
137+ FROM secshareincrease WHERE secshareincrease.ticker = analysis.symbol)
138+ WHERE EXISTS (SELECT 1 FROM secshareincrease WHERE secshareincrease.ticker = analysis.symbol);"""
139+ )
140+ session .exec (query ) # type: ignore
141+ session .commit ()
142+
119143 def _read_analysis_data (
120144 self , columns : List [str ], symbols : Optional [List [str ]] = None
121145 ) -> pd .DataFrame :
122- columns__ = columns .copy ()
123146
124- columns_ = "," .join (columns__ )
147+ columns_ = "," .join (columns )
125148
126149 if symbols :
127150 symbols_str = "," .join ([f"'{ s } '" for s in symbols ])
128- query = f"""SELECT { columns_ } FROM analysis
129- LEFT JOIN secshareincrease ON secshareincrease.ticker=analysis.symbol WHERE symbol IN ({ symbols_str } )""" # noqa: S608
151+ query = f"""SELECT { columns_ } FROM analysis WHERE symbol IN ({ symbols_str } )""" # noqa: S608
130152 else :
131- query = f"""SELECT { columns_ } FROM analysis
132- LEFT JOIN secshareincrease ON secshareincrease.ticker=analysis.symbol""" # noqa: S608
153+ query = f"""SELECT { columns_ } FROM analysis""" # noqa: S608
133154 return pd .read_sql_query (query , self ._engine )
134155
135156 def _read_filter_query (self , query : str ) -> pd .DataFrame :
0 commit comments