Skip to content

Commit b205de8

Browse files
authored
Merge pull request #17 from AKKI0511/codex/refactor-fetch_data-for-parallel-download
Enable parallel data loading
2 parents 2173546 + 7a31e4a commit b205de8

File tree

4 files changed

+114
-36
lines changed

4 files changed

+114
-36
lines changed

config/model_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ data:
1010
cache_expiration_days: 7 # Refresh cache after this many days
1111
use_cache: true
1212
refresh: false
13+
max_workers: 1
1314
test_start: '2025-01-01'
1415
test_end: '2025-01-31'
1516

src/data/loader.py

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datetime import datetime, timedelta
66
import yaml
77
from pydantic import ValidationError
8+
from concurrent.futures import ThreadPoolExecutor, as_completed
89

910
from utils.config_schemas import ModelConfigSchema
1011
import os
@@ -36,6 +37,7 @@ def __init__(self, config_path: str = "config/model_config.yaml"):
3637
self.cache_expiration_days = data_cfg.cache_expiration_days
3738
self.use_cache = data_cfg.use_cache
3839
self.default_refresh = data_cfg.refresh
40+
self.max_workers = data_cfg.max_workers or 1
3941

4042
def _is_cache_valid(self, cache_file: str) -> bool:
4143
"""Return True if the cache file exists and is not expired."""
@@ -46,6 +48,37 @@ def _is_cache_valid(self, cache_file: str) -> bool:
4648
file_time = datetime.fromtimestamp(os.path.getmtime(cache_file))
4749
return datetime.now() - file_time < timedelta(days=self.cache_expiration_days)
4850

51+
def _fetch_single(self, symbol: str, refresh: bool) -> Optional[pd.DataFrame]:
    """Fetch data for a single symbol and handle caching.

    Loads from the parquet cache when allowed and fresh; otherwise pulls
    history from yfinance and (optionally) writes it back to the cache.
    Returns None when no data is available or any step fails.
    """
    cache_file = os.path.join(self.cache_dir, f"{symbol}_data.parquet")
    try:
        # Decide once whether the cached copy may be served.
        serve_from_cache = (
            self.use_cache and not refresh and self._is_cache_valid(cache_file)
        )
        if serve_from_cache:
            logger.info(f"Loading cached data for {symbol} from {cache_file}")
            frame = pd.read_parquet(cache_file)
        else:
            logger.info(f"Fetching data for {symbol}")
            frame = yf.Ticker(symbol).history(
                start=self.start_date, end=self.end_date
            )

            if frame.empty:
                logger.error(f"No data found for {symbol}")
                return None

            # Persist the freshly downloaded frame only; a cache hit must
            # not rewrite the file (that would reset its expiration clock).
            if self.use_cache:
                os.makedirs(self.cache_dir, exist_ok=True)
                frame.to_parquet(cache_file)
                logger.info(f"Cached data for {symbol} at {cache_file}")

        # Completeness check applies to cached and fresh data alike.
        missing_dates = self._check_missing_dates(frame)
        if missing_dates:
            logger.warning(f"Missing dates for {symbol}: {len(missing_dates)} days")

        logger.info(f"Successfully retrieved {len(frame)} records for {symbol}")
        return frame
    except Exception as e:
        # Best-effort per-symbol fetch: a failure here must not abort the
        # batch, so log and signal "no data" to the caller.
        logger.error(f"Error fetching data for {symbol}: {str(e)}")
        return None
81+
4982
def fetch_data(
    self, symbols: Optional[List[str]] = None, refresh: Optional[bool] = None
) -> Dict[str, pd.DataFrame]:
    """Fetch price history for *symbols*, in parallel when configured.

    Args:
        symbols: Tickers to load; defaults to the configured symbol list.
        refresh: Force a re-download, bypassing the cache; defaults to the
            configured ``refresh`` flag when None.

    Returns:
        Mapping of symbol -> DataFrame, in the order the symbols were
        requested. Symbols that yielded no data are omitted.
    """
    symbols = symbols or self.symbols
    refresh = self.default_refresh if refresh is None else refresh

    # Collect per-symbol results here first; the final dict is built in
    # request order below so parallel and sequential runs agree.
    results: Dict[str, pd.DataFrame] = {}

    if self.max_workers and self.max_workers > 1:
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = {
                executor.submit(self._fetch_single, s, refresh): s for s in symbols
            }
            for future in as_completed(futures):
                symbol = futures[future]
                # _fetch_single swallows its own exceptions and returns
                # None on failure, so result() does not raise here.
                df = future.result()
                if df is not None:
                    results[symbol] = df
    else:
        for symbol in symbols:
            df = self._fetch_single(symbol, refresh)
            if df is not None:
                results[symbol] = df

    # Fix: as_completed yields in completion order, which made the dict's
    # iteration order nondeterministic in parallel mode. Re-emit results
    # in the caller's symbol order for stable, mode-independent output.
    data_dict: Dict[str, pd.DataFrame] = {
        s: results[s] for s in symbols if s in results
    }

    return data_dict
102116

src/utils/config_schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class DataSection(BaseModel):
1313
refresh: Optional[bool] = False
1414
test_start: Optional[str] = None
1515
test_end: Optional[str] = None
16+
max_workers: Optional[int] = 1
1617

1718

1819
class ModelConfigSchema(BaseModel):

tests/data/test_loader_extra.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
from unittest.mock import patch
23
import pandas as pd
34
import os
45
import yaml
@@ -80,5 +81,66 @@ def test_is_cache_valid_expired(self):
8081
self.assertFalse(self.loader._is_cache_valid(self.cache_file))
8182

8283

84+
class TestFetchDataParallel(unittest.TestCase):
    """Verify fetch_data uses a thread pool when max_workers > 1."""

    def setUp(self):
        # Isolated temp workspace: a cache dir plus a config file that
        # enables parallel loading (max_workers: 2) with caching disabled.
        self.tmpdir = tempfile.mkdtemp()
        self.cache_dir = os.path.join(self.tmpdir, "cache")
        os.makedirs(self.cache_dir, exist_ok=True)
        self.config_path = os.path.join(self.tmpdir, "config.yaml")
        config = {
            "data": {
                "symbols": ["AAA", "BBB"],
                "start_date": "2020-01-01",
                "end_date": "2020-01-02",
                "cache_path": self.cache_dir,
                "use_cache": False,
                "max_workers": 2,
            }
        }
        with open(self.config_path, "w") as f:
            yaml.dump(config, f)

    def tearDown(self):
        shutil.rmtree(self.tmpdir)

    # as_completed is patched to the identity function: iterating the
    # futures dict yields the (already-resolved) dummy futures directly.
    @patch("data.loader.as_completed", side_effect=lambda fs: fs)
    @patch("data.loader.ThreadPoolExecutor")
    @patch("yfinance.Ticker")
    def test_parallel_execution(self, mock_ticker, mock_executor, _mock_ac):
        class DummyFuture:
            # Minimal Future stand-in: holds a precomputed result.
            def __init__(self, result):
                self._result = result

            def result(self):
                return self._result

        class DummyExecutor:
            # Synchronous executor: submit() runs the callable immediately
            # and wraps the return value in a resolved DummyFuture.
            def __init__(self, max_workers=None):
                self.max_workers = max_workers

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                pass

            def submit(self, fn, *args):
                return DummyFuture(fn(*args))

        mock_executor.return_value = DummyExecutor(max_workers=2)
        # One-row OHLCV frame returned for every ticker history call.
        mock_history = pd.DataFrame(
            {"Open": [1], "High": [1], "Low": [1], "Close": [1], "Volume": [1]},
            index=pd.date_range("2020-01-01", periods=1),
        )
        mock_ticker.return_value.history.return_value = mock_history

        loader = DataLoader(self.config_path)
        data = loader.fetch_data()

        # The configured worker count must reach the executor, and both
        # symbols must come back with data.
        mock_executor.assert_called_once_with(max_workers=2)
        self.assertIn("AAA", data)
        self.assertIn("BBB", data)
143+
144+
83145
if __name__ == "__main__":
84146
unittest.main()

Comments (0)