Skip to content

Commit a148789

Browse files
Authored — Merge pull request #49 from AKKI0511/analyze-quanttradeai-for-new-feature
feat: support secondary timeframe ingestion
2 parents 271f5ad + c724f9c commit a148789

File tree

7 files changed

+416
-98
lines changed

7 files changed

+416
-98
lines changed

config/model_config.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ data:
55
symbols: ['AAPL', 'META', 'TSLA', 'JPM', 'AMZN']
66
start_date: '2015-01-01'
77
end_date: '2024-12-31'
8-
timeframe: '1d'
9-
cache_dir: 'data/raw'
8+
timeframe: '1d'
9+
secondary_timeframes:
10+
- '1h'
11+
- '30m'
12+
cache_dir: 'data/raw'
1013
cache_path: 'data/raw' # Directory to load/store cached OHLCV files
1114
cache_expiration_days: 7 # Refresh cache after this many days
1215
use_cache: true

docs/configuration.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ data:
2424
end_date: '2024-12-31'
2525
cache_dir: 'data/raw'
2626
cache_path: 'data/raw'
27+
secondary_timeframes:
28+
- '1h'
29+
- '30m'
2730
cache_expiration_days: 7
2831
use_cache: true
2932
refresh: false
@@ -37,6 +40,7 @@ data:
3740
- `symbols`: List of stock symbols to process
3841
- `start_date`/`end_date`: Data date range
3942
- `cache_dir`: Directory for cached data
43+
- `secondary_timeframes`: Optional list of higher-frequency bars to resample into the primary `timeframe` using OHLCV aggregations (`open→first`, `high→max`, `low→min`, `close→last`, `volume→sum`)
4044
- `use_cache`: Enable/disable caching
4145
- `refresh`: Force fresh data download
4246
- `max_workers`: Parallel processing workers

quanttradeai/data/loader.py

Lines changed: 133 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(
5252
self.symbols = data_cfg.symbols
5353
self.start_date = data_cfg.start_date
5454
self.end_date = data_cfg.end_date
55-
self.timeframe = data_cfg.timeframe or "1d"
55+
self.timeframe = data_cfg.timeframe or "1d"
56+
self.secondary_timeframes = data_cfg.secondary_timeframes or []
5657
# allow both legacy 'cache_dir' and new 'cache_path' keys
5758
self.cache_dir = data_cfg.cache_path or data_cfg.cache_dir or "data/raw"
5859
self.cache_expiration_days = data_cfg.cache_expiration_days
@@ -72,37 +73,66 @@ def _is_cache_valid(self, cache_file: str) -> bool:
7273

7374
def _fetch_single(self, symbol: str, refresh: bool) -> Optional[pd.DataFrame]:
7475
"""Fetch data for a single symbol and handle caching."""
75-
cache_file = os.path.join(
76-
self.cache_dir, f"{symbol}_{self.timeframe}_data.parquet"
77-
)
78-
try:
79-
if self.use_cache and not refresh and self._is_cache_valid(cache_file):
80-
logger.info(f"Loading cached data for {symbol} from {cache_file}")
81-
df = pd.read_parquet(cache_file)
82-
else:
83-
logger.info(f"Fetching data for {symbol}")
84-
df = self.data_source.fetch(
85-
symbol, self.start_date, self.end_date, self.timeframe
86-
)
87-
88-
if df is None or df.empty:
89-
logger.error(f"No data found for {symbol}")
90-
return None
91-
92-
if self.use_cache:
93-
os.makedirs(self.cache_dir, exist_ok=True)
94-
df.to_parquet(cache_file)
95-
logger.info(f"Cached data for {symbol} at {cache_file}")
96-
97-
missing_dates = self._check_missing_dates(df)
98-
if missing_dates:
99-
logger.warning(f"Missing dates for {symbol}: {len(missing_dates)} days")
100-
101-
logger.info(f"Successfully retrieved {len(df)} records for {symbol}")
102-
return df
103-
except Exception as e:
104-
logger.error(f"Error fetching data for {symbol}: {str(e)}")
105-
return None
76+
cache_file = os.path.join(
77+
self.cache_dir, f"{symbol}_{self.timeframe}_data.parquet"
78+
)
79+
try:
80+
df = self._load_timeframe_data(
81+
symbol=symbol,
82+
timeframe=self.timeframe,
83+
cache_file=cache_file,
84+
refresh=refresh,
85+
)
86+
87+
if df is None or df.empty:
88+
logger.error(f"No data found for {symbol}")
89+
return None
90+
91+
df = df.sort_index()
92+
93+
for secondary_tf in self.secondary_timeframes:
94+
secondary_cache = os.path.join(
95+
self.cache_dir, f"{symbol}_{secondary_tf}_data.parquet"
96+
)
97+
secondary_df = self._load_timeframe_data(
98+
symbol=symbol,
99+
timeframe=secondary_tf,
100+
cache_file=secondary_cache,
101+
refresh=refresh,
102+
)
103+
104+
if secondary_df is None or secondary_df.empty:
105+
logger.warning(
106+
"No data found for %s at secondary timeframe %s",
107+
symbol,
108+
secondary_tf,
109+
)
110+
continue
111+
112+
try:
113+
resampled = self._resample_secondary(
114+
secondary_df, df.index, secondary_tf
115+
)
116+
except ValueError as exc:
117+
logger.warning(
118+
"Skipping secondary timeframe %s for %s: %s",
119+
secondary_tf,
120+
symbol,
121+
exc,
122+
)
123+
continue
124+
125+
df = df.join(resampled, how="left")
126+
127+
missing_dates = self._check_missing_dates(df)
128+
if missing_dates:
129+
logger.warning(f"Missing dates for {symbol}: {len(missing_dates)} days")
130+
131+
logger.info(f"Successfully retrieved {len(df)} records for {symbol}")
132+
return df
133+
except Exception as e:
134+
logger.error(f"Error fetching data for {symbol}: {str(e)}")
135+
return None
106136

107137
def fetch_data(
108138
self, symbols: Optional[List[str]] = None, refresh: Optional[bool] = None
@@ -139,11 +169,77 @@ def fetch_data(
139169

140170
return data_dict
141171

142-
def _check_missing_dates(self, df: pd.DataFrame) -> List[datetime]:
143-
"""Check for missing trading days in the data."""
144-
all_dates = pd.date_range(start=df.index.min(), end=df.index.max(), freq="B")
145-
missing_dates = all_dates.difference(df.index)
146-
return list(missing_dates)
172+
def _check_missing_dates(self, df: pd.DataFrame) -> List[datetime]:
173+
"""Check for missing trading days in the data."""
174+
all_dates = pd.date_range(start=df.index.min(), end=df.index.max(), freq="B")
175+
missing_dates = all_dates.difference(df.index)
176+
return list(missing_dates)
177+
178+
def _load_timeframe_data(
179+
self, symbol: str, timeframe: str, cache_file: str, refresh: bool
180+
) -> Optional[pd.DataFrame]:
181+
"""Load data for a given timeframe from cache or datasource."""
182+
183+
if self.use_cache and not refresh and self._is_cache_valid(cache_file):
184+
logger.info(
185+
"Loading cached data for %s (%s) from %s", symbol, timeframe, cache_file
186+
)
187+
return pd.read_parquet(cache_file)
188+
189+
logger.info(f"Fetching data for {symbol} ({timeframe})")
190+
df = self.data_source.fetch(
191+
symbol, self.start_date, self.end_date, timeframe
192+
)
193+
194+
if df is None or df.empty:
195+
return None
196+
197+
if self.use_cache:
198+
os.makedirs(self.cache_dir, exist_ok=True)
199+
df.to_parquet(cache_file)
200+
logger.info(
201+
"Cached data for %s (%s) at %s", symbol, timeframe, cache_file
202+
)
203+
204+
return df
205+
206+
def _resample_secondary(
207+
self,
208+
df: pd.DataFrame,
209+
target_index: pd.Index,
210+
source_timeframe: str,
211+
) -> pd.DataFrame:
212+
"""Resample a secondary timeframe to the loader's primary timeframe."""
213+
214+
if not isinstance(df.index, pd.DatetimeIndex):
215+
df = df.copy()
216+
df.index = pd.to_datetime(df.index)
217+
218+
df = df.sort_index()
219+
220+
required_columns = {
221+
"Open": "first",
222+
"High": "max",
223+
"Low": "min",
224+
"Close": "last",
225+
"Volume": "sum",
226+
}
227+
228+
missing = [col for col in required_columns if col not in df.columns]
229+
if missing:
230+
raise ValueError(
231+
f"Secondary timeframe data missing required columns: {missing}"
232+
)
233+
234+
resampled = df.resample(self.timeframe).agg(required_columns)
235+
236+
renamed = {
237+
col: f"{col.lower()}_{source_timeframe}_{required_columns[col]}"
238+
for col in required_columns
239+
}
240+
resampled = resampled.rename(columns=renamed)
241+
resampled = resampled.reindex(target_index)
242+
return resampled
147243

148244
def validate_data(self, data_dict: Dict[str, pd.DataFrame]) -> bool:
149245
"""

quanttradeai/utils/config_schemas.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
from pydantic import BaseModel, Field
1414

1515

16-
class DataSection(BaseModel):
17-
symbols: List[str]
18-
start_date: str
19-
end_date: str
20-
timeframe: Optional[str] = "1d"
21-
cache_path: Optional[str] = None
22-
cache_dir: Optional[str] = None
16+
class DataSection(BaseModel):
17+
symbols: List[str]
18+
start_date: str
19+
end_date: str
20+
timeframe: Optional[str] = "1d"
21+
secondary_timeframes: Optional[List[str]] = None
22+
cache_path: Optional[str] = None
23+
cache_dir: Optional[str] = None
2324
cache_expiration_days: Optional[int] = None
2425
use_cache: Optional[bool] = True
2526
refresh: Optional[bool] = False

tests/data/test_loader.py

Lines changed: 100 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from unittest.mock import patch
2+
from unittest.mock import patch, call
33
import pandas as pd
44
import os
55
import yaml
@@ -30,21 +30,23 @@ def setUp(self):
3030
)
3131
self.df.to_parquet(os.path.join(self.cache_dir, "TEST_1d_data.parquet"))
3232

33-
def _write_config(self, expiration):
34-
config = {
35-
"data": {
36-
"symbols": ["TEST"],
37-
"start_date": "2020-01-01",
38-
"end_date": "2020-01-10",
39-
"cache_path": self.cache_dir,
40-
"timeframe": "1d",
41-
"cache_expiration_days": expiration,
42-
"use_cache": True,
43-
"refresh": False,
44-
}
45-
}
46-
with open(self.config_path, "w") as f:
47-
yaml.dump(config, f)
33+
def _write_config(self, expiration, secondary_timeframes=None):
34+
config = {
35+
"data": {
36+
"symbols": ["TEST"],
37+
"start_date": "2020-01-01",
38+
"end_date": "2020-01-10",
39+
"cache_path": self.cache_dir,
40+
"timeframe": "1d",
41+
"cache_expiration_days": expiration,
42+
"use_cache": True,
43+
"refresh": False,
44+
}
45+
}
46+
if secondary_timeframes is not None:
47+
config["data"]["secondary_timeframes"] = secondary_timeframes
48+
with open(self.config_path, "w") as f:
49+
yaml.dump(config, f)
4850

4951
def tearDown(self):
5052
shutil.rmtree(self.tmpdir)
@@ -78,20 +80,88 @@ def test_fetch_data_refreshes_cache(self, mock_fetch):
7880
pd.testing.assert_frame_equal(data_dict["TEST"], mock_history)
7981

8082
@patch("quanttradeai.data.datasource.YFinanceDataSource.fetch")
81-
def test_fetch_data_expired_cache(self, mock_fetch):
82-
self._write_config(expiration=0)
83-
84-
mock_history = pd.DataFrame(
85-
{"Open": [1], "High": [1], "Low": [1], "Close": [1], "Volume": [1]},
86-
index=pd.date_range("2020-01-01", periods=1),
87-
)
88-
mock_fetch.return_value = mock_history
89-
90-
loader = DataLoader(self.config_path)
91-
data_dict = loader.fetch_data()
92-
93-
mock_fetch.assert_called_once()
94-
pd.testing.assert_frame_equal(data_dict["TEST"], mock_history)
83+
def test_fetch_data_expired_cache(self, mock_fetch):
84+
self._write_config(expiration=0)
85+
86+
mock_history = pd.DataFrame(
87+
{"Open": [1], "High": [1], "Low": [1], "Close": [1], "Volume": [1]},
88+
index=pd.date_range("2020-01-01", periods=1),
89+
)
90+
mock_fetch.return_value = mock_history
91+
92+
loader = DataLoader(self.config_path)
93+
data_dict = loader.fetch_data()
94+
95+
mock_fetch.assert_called_once()
96+
pd.testing.assert_frame_equal(data_dict["TEST"], mock_history)
97+
98+
@patch("quanttradeai.data.datasource.YFinanceDataSource.fetch")
99+
def test_fetch_data_with_secondary_timeframes(self, mock_fetch):
100+
self._write_config(expiration=10, secondary_timeframes=["1h"])
101+
102+
primary_index = pd.date_range("2020-01-01", periods=2, freq="D")
103+
primary_df = pd.DataFrame(
104+
{
105+
"Open": [100.0, 110.0],
106+
"High": [101.0, 111.0],
107+
"Low": [99.0, 109.0],
108+
"Close": [100.5, 110.5],
109+
"Volume": [1000, 1100],
110+
},
111+
index=primary_index,
112+
)
113+
114+
hourly_index = pd.date_range("2020-01-01", periods=48, freq="h")
115+
hourly_df = pd.DataFrame(
116+
{
117+
"Open": range(48),
118+
"High": [value + 1 for value in range(48)],
119+
"Low": range(48),
120+
"Close": [value + 0.5 for value in range(48)],
121+
"Volume": [10] * 48,
122+
},
123+
index=hourly_index,
124+
)
125+
126+
mock_fetch.side_effect = [primary_df, hourly_df]
127+
128+
loader = DataLoader(self.config_path)
129+
data_dict = loader.fetch_data(refresh=True)
130+
131+
df = data_dict["TEST"]
132+
133+
expected_secondary = (
134+
hourly_df.resample("1D")
135+
.agg({
136+
"Open": "first",
137+
"High": "max",
138+
"Low": "min",
139+
"Close": "last",
140+
"Volume": "sum",
141+
})
142+
.rename(
143+
columns={
144+
"Open": "open_1h_first",
145+
"High": "high_1h_max",
146+
"Low": "low_1h_min",
147+
"Close": "close_1h_last",
148+
"Volume": "volume_1h_sum",
149+
}
150+
)
151+
.reindex(primary_index)
152+
)
153+
154+
for column in expected_secondary.columns:
155+
assert column in df.columns
156+
pd.testing.assert_series_equal(
157+
df[column], expected_secondary[column], check_names=True
158+
)
159+
160+
expected_calls = [
161+
call("TEST", "2020-01-01", "2020-01-10", "1d"),
162+
call("TEST", "2020-01-01", "2020-01-10", "1h"),
163+
]
164+
assert mock_fetch.call_args_list == expected_calls
95165

96166

97167
if __name__ == "__main__":

0 commit comments

Comments (0)