99from tensorflow .keras .models import Model
1010from tensorflow .keras .layers import Input , LSTM , Dense , Dropout
1111from tensorflow .keras .optimizers import Adam
12+ import os
13+
1214
1315def clean_column_names (df ):
1416 new_cols = []
@@ -21,6 +23,7 @@ def clean_column_names(df):
2123 df .columns = new_cols
2224 return df
2325
26+
2427def fetch_stock_data (ticker , period , interval ):
2528 data = yf .download (tickers = ticker , period = period , interval = interval )
2629 data = data .reset_index ()
@@ -40,6 +43,7 @@ def fetch_stock_data(ticker, period, interval):
4043 data .reset_index (inplace = True )
4144 return data
4245
46+
4347def fetch_options_data (symbol , days_to_fetch ):
4448 all_data = []
4549 for i in range (days_to_fetch ):
@@ -113,7 +117,15 @@ def clean_and_merge(stock_df, options_df):
113117
114118 return merged
115119
116- def train_and_predict (merged_df , sequence_length = 10 , epochs = 30 ):
120+
121+ def train_and_predict (merged_df , sequence_length = 10 , epochs = 30 , use_transfer_learning = True ):
122+ """
123+ Train LSTM model with optional transfer learning.
124+
125+ If base weights exist, loads them and fine-tunes (5 epochs).
126+ Otherwise, trains from scratch (30 epochs).
127+ """
128+
117129 features = [
118130 'Open' ,'High' ,'Low' ,'Close' ,'Volume' ,'MA_5' ,'MA_10' ,'Volatility' ,
119131 'CE_openInterest' ,'PE_openInterest' ,'CE_changeinOpenInterest' ,'PE_changeinOpenInterest' ,
@@ -123,33 +135,111 @@ def train_and_predict(merged_df, sequence_length=10, epochs=30):
123135 'Close' ,'Return' ,'Volatility' ,'CE_openInterest' ,'PE_openInterest' ,
124136 'CE_changeinOpenInterest' ,'PE_changeinOpenInterest' ,'PCR'
125137 ]
138+
126139 df = merged_df .copy ()
127140 df .fillna (0 , inplace = True )
128141 X = df [features ].values
129142 y = df [targets ].values
143+
130144 scaler_X = MinMaxScaler ()
131145 scaler_y = MinMaxScaler ()
132146 X_scaled = scaler_X .fit_transform (X )
133147 y_scaled = scaler_y .fit_transform (y )
148+
134149 X_seq , y_seq = [], []
135150 for i in range (len (X_scaled ) - sequence_length ):
136151 X_seq .append (X_scaled [i :i + sequence_length ])
137152 y_seq .append (y_scaled [i + sequence_length ])
138153 X_seq = np .array (X_seq )
139154 y_seq = np .array (y_seq )
140- inputs = Input (shape = (X_seq .shape [1 ], X_seq .shape [2 ]))
141- x = LSTM (128 , return_sequences = True )(inputs )
142- x = LSTM (64 )(x )
143- x = Dropout (0.2 )(x )
144- outputs = Dense (y_seq .shape [1 ], activation = 'linear' )(x )
155+
156+ # Build model architecture
157+ inputs = Input (shape = (X_seq .shape [1 ], X_seq .shape [2 ]), name = 'input' )
158+ x = LSTM (128 , return_sequences = True , name = 'lstm_1' )(inputs )
159+ x = LSTM (64 , name = 'lstm_2' )(x )
160+ x = Dropout (0.2 , name = 'dropout_1' )(x )
161+ outputs = Dense (y_seq .shape [1 ], activation = 'linear' , name = 'dense_output' )(x )
162+
145163 model = Model (inputs , outputs )
164+
# =================== TRANSFER LEARNING ===================
# Optionally warm-start `model` from pre-trained base weights, freeze the two
# LSTM layers and shorten training to a brief fine-tune. If the weights are
# missing, transfer learning is disabled, or loading fails for any reason,
# every layer stays trainable and the caller's `epochs` value is used.
base_weights_path = 'models/base_stock_weights.weights.h5'
fine_tune_epochs = 5
requested_epochs = epochs  # remembered for the log message and for fallback
weights_available = os.path.exists(base_weights_path)

if use_transfer_learning and weights_available:
    try:
        print("\n " + "=" * 60)
        print("🔄 TRANSFER LEARNING ENABLED")
        print("=" * 60)
        print("📥 Loading pre-trained base model weights...")

        try:
            # Keras 3.x call form: `.weights.h5` files take no `by_name`.
            model.load_weights(base_weights_path, skip_mismatch=True)
        except (TypeError, ValueError):
            # Keras 2.x rejects `skip_mismatch=True` without `by_name=True`
            # with a ValueError (not a TypeError), so both must be caught
            # for this fallback to be reachable. Here weights are matched
            # to layers by the explicit layer names given above.
            model.load_weights(base_weights_path, by_name=True, skip_mismatch=True)
        print("✅ Base weights loaded successfully!")
        # NOTE(review): with skip_mismatch=True, layers whose shapes do not
        # match the checkpoint are skipped rather than raising; a skipped
        # LSTM layer would then be frozen with its initial weights.
        # TODO confirm the checkpoint matches this architecture.

        # Freeze LSTM layers to preserve learned patterns. This must happen
        # before the model is compiled for the setting to take effect.
        model.get_layer('lstm_1').trainable = False
        model.get_layer('lstm_2').trainable = False
        print("✅ Frozen LSTM layers (keeping general market patterns)")

        # Count layers still flagged trainable (any layer with
        # `trainable=True` is counted, weights or not).
        trainable_count = sum(1 for layer in model.layers if layer.trainable)
        print(f"✅ Training only {trainable_count} layers: Dropout + Dense")

        # Reduce epochs for fine-tuning; report the value actually replaced
        # instead of assuming the default.
        epochs = fine_tune_epochs
        print(f"✅ Reduced training epochs: {epochs} (instead of {requested_epochs})")
        print("=" * 60 + "\n ")

        print("⚡ Expected training time: 15-30 seconds (10x faster!)\n ")

    except Exception as e:
        # Best-effort feature: never let a bad checkpoint abort training.
        print(f"\n ⚠️ Could not load base weights: {e} ")
        print("⚠️ Falling back to training from scratch...\n ")
        # Undo any partial freezing and restore the full training budget.
        for layer in model.layers:
            layer.trainable = True
        epochs = requested_epochs
else:
    if not weights_available:
        print("\n " + "=" * 60)
        print("ℹ️ BASE WEIGHTS NOT FOUND")
        print("=" * 60)
        print(f"📁 Looking for: {base_weights_path} ")
        print("💡 To enable transfer learning (10x speedup):")
        print(" 1. Run: python Train_Base_Model.py")
        print(" 2. Wait ~15 minutes (one-time setup)")
        print(" 3. Enjoy 15-30 second predictions forever!")
        print("=" * 60 + "\n ")

    print("🔨 Training from scratch (this will take 2-5 minutes)...\n ")
# =========================================================
220+
146221 model .compile (optimizer = Adam (learning_rate = 0.001 ), loss = 'mse' )
147- model .fit (X_seq , y_seq , epochs = epochs , batch_size = 32 , validation_split = 0.1 , shuffle = False , verbose = 1 )
222+
223+ model .fit (
224+ X_seq , y_seq ,
225+ epochs = epochs ,
226+ batch_size = 32 ,
227+ validation_split = 0.1 ,
228+ shuffle = False ,
229+ verbose = 1
230+ )
231+
232+ # Unfreeze all layers for prediction (if any were frozen)
233+ for layer in model .layers :
234+ layer .trainable = True
235+
148236 y_pred_scaled = model .predict (X_seq )
149237 y_pred = scaler_y .inverse_transform (y_pred_scaled )
150238 pred_df = pd .DataFrame (y_pred , columns = targets )
239+
151240 return pred_df
152241
242+
153243def summarize_predictions (pred_df ):
154244 summary_features = {}
155245 for col in pred_df .columns :
@@ -162,7 +252,6 @@ def summarize_predictions(pred_df):
162252 return summary_features
163253
164254
165-
166255def main_pipeline (ticker , symbol , period , interval , days_to_fetch ):
167256 stock_df = fetch_stock_data (ticker = ticker , period = period , interval = interval )
168257 options_df = fetch_options_data (symbol = symbol , days_to_fetch = days_to_fetch )
@@ -171,6 +260,7 @@ def main_pipeline(ticker, symbol, period, interval, days_to_fetch):
171260 summary_dict = summarize_predictions (pred_df )
172261 return json .dumps (summary_dict , indent = 4 )
173262
263+
174264if __name__ == "__main__" :
175265 import sys
176266
0 commit comments