55from datetime import datetime , timedelta
66import yaml
77from pydantic import ValidationError
8+ from concurrent .futures import ThreadPoolExecutor , as_completed
89
910from utils .config_schemas import ModelConfigSchema
1011import os
@@ -36,6 +37,7 @@ def __init__(self, config_path: str = "config/model_config.yaml"):
3637 self .cache_expiration_days = data_cfg .cache_expiration_days
3738 self .use_cache = data_cfg .use_cache
3839 self .default_refresh = data_cfg .refresh
40+ self .max_workers = data_cfg .max_workers or 1
3941
4042 def _is_cache_valid (self , cache_file : str ) -> bool :
4143 """Return True if the cache file exists and is not expired."""
@@ -46,6 +48,37 @@ def _is_cache_valid(self, cache_file: str) -> bool:
4648 file_time = datetime .fromtimestamp (os .path .getmtime (cache_file ))
4749 return datetime .now () - file_time < timedelta (days = self .cache_expiration_days )
4850
def _fetch_single(self, symbol: str, refresh: bool) -> Optional[pd.DataFrame]:
    """Fetch price history for a single symbol, using the parquet cache when valid.

    Args:
        symbol: Ticker symbol to fetch.
        refresh: When True, bypass the cache and re-download.

    Returns:
        The symbol's price history as a DataFrame, or None when no data is
        available or an error occurred (errors are logged, never raised).
    """
    cache_file = os.path.join(self.cache_dir, f"{symbol}_data.parquet")
    try:
        if self.use_cache and not refresh and self._is_cache_valid(cache_file):
            logger.info(f"Loading cached data for {symbol} from {cache_file}")
            df = pd.read_parquet(cache_file)
            fetched = False
        else:
            logger.info(f"Fetching data for {symbol}")
            ticker = yf.Ticker(symbol)
            df = ticker.history(start=self.start_date, end=self.end_date)
            fetched = True

        if df.empty:
            logger.error(f"No data found for {symbol}")
            return None

        # Persist only freshly downloaded data. Re-writing on a cache hit
        # would bump the file's mtime — which _is_cache_valid checks — so
        # frequently-read entries would have their expiration window
        # extended indefinitely and never refresh.
        if self.use_cache and fetched:
            os.makedirs(self.cache_dir, exist_ok=True)
            df.to_parquet(cache_file)
            logger.info(f"Cached data for {symbol} at {cache_file}")

        # Completeness check is advisory: log a warning but still return data.
        missing_dates = self._check_missing_dates(df)
        if missing_dates:
            logger.warning(f"Missing dates for {symbol}: {len(missing_dates)} days")

        logger.info(f"Successfully retrieved {len(df)} records for {symbol}")
        return df
    except Exception as e:
        # Broad catch is deliberate: one bad symbol must not abort the batch.
        logger.error(f"Error fetching data for {symbol}: {str(e)}")
        return None
81+
def fetch_data(
    self, symbols: Optional[List[str]] = None, refresh: Optional[bool] = None
) -> Dict[str, pd.DataFrame]:
    """Fetch data for the given symbols, in parallel when configured.

    Args:
        symbols: Symbols to fetch; defaults to the configured symbol list.
        refresh: Force re-download, bypassing the cache; defaults to the
            configured refresh setting.

    Returns:
        Mapping of symbol to its DataFrame, in input order. Symbols whose
        fetch failed or returned no data are omitted (per-symbol errors are
        logged inside _fetch_single, never raised).
    """
    symbols = symbols or self.symbols
    refresh = self.default_refresh if refresh is None else refresh

    if self.max_workers and self.max_workers > 1:
        # executor.map yields results in input order (unlike as_completed),
        # so the returned dict is deterministic regardless of fetch timing.
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            frames = list(
                executor.map(lambda s: self._fetch_single(s, refresh), symbols)
            )
    else:
        frames = [self._fetch_single(s, refresh) for s in symbols]

    # Single shared collection step for both paths: drop failed symbols.
    return {symbol: df for symbol, df in zip(symbols, frames) if df is not None}
102116
0 commit comments