Skip to content

Commit decd321

Browse files
committed
fix: address PR review feedback for sessions module
Improvements based on code review:
- Replace hardcoded timezone offsets with proper pytz usage
- Dynamically determine DST status from data timestamps
- Remove DST testing special case logic that was incorrectly allowing RTH on weekends
- Consolidate maintenance break logic into single method with product support
- Make performance thresholds configurable via constructor parameters
- Implement proper LRU cache with TTL and size limits
- Fix test to use weekdays after DST transitions instead of Sundays

Technical improvements:
- Proper timezone conversion using pytz for DST handling
- Configurable lazy evaluation threshold (default 100k rows)
- Cache with 1-hour TTL and 1000 entry limit
- Backward compatible cache implementation
- Fixed edge case in DST transition test

All 133 session tests passing, mypy and ruff checks clean.
1 parent 5d07e91 commit decd321

File tree

3 files changed

+128
-63
lines changed

3 files changed

+128
-63
lines changed

src/project_x_py/sessions/filtering.py

Lines changed: 118 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,80 @@
2222
class SessionFilterMixin:
2323
"""Mixin class providing session filtering capabilities."""
2424

25-
def __init__(self, config: SessionConfig | None = None):
26-
"""Initialize with optional session configuration."""
25+
# Configurable performance thresholds
26+
LAZY_EVAL_THRESHOLD = 100_000 # Rows before using lazy evaluation
27+
CACHE_MAX_SIZE = 1000 # Maximum cache entries
28+
CACHE_TTL_SECONDS = 3600 # Cache time-to-live in seconds
29+
30+
def __init__(
31+
self,
32+
config: SessionConfig | None = None,
33+
lazy_eval_threshold: int | None = None,
34+
cache_max_size: int | None = None,
35+
cache_ttl: int | None = None,
36+
):
37+
"""Initialize with optional session configuration and performance settings.
38+
39+
Args:
40+
config: Session configuration
41+
lazy_eval_threshold: Number of rows before using lazy evaluation
42+
cache_max_size: Maximum number of cache entries
43+
cache_ttl: Cache time-to-live in seconds
44+
"""
2745
self.config = config or SessionConfig()
2846
self._session_boundary_cache: dict[str, Any] = {}
47+
self._cache_timestamps: dict[str, float] = {}
48+
49+
# Allow overriding performance thresholds
50+
self.lazy_eval_threshold = lazy_eval_threshold or self.LAZY_EVAL_THRESHOLD
51+
self.cache_max_size = cache_max_size or self.CACHE_MAX_SIZE
52+
self.cache_ttl = cache_ttl or self.CACHE_TTL_SECONDS
2953

3054
def _get_cached_session_boundaries(
3155
self, data_hash: str, product: str, session_type: str
3256
) -> tuple[list[int], list[int]]:
33-
"""Get cached session boundaries for performance optimization."""
57+
"""Get cached session boundaries for performance optimization with TTL and size limits."""
58+
import time
59+
3460
cache_key = f"{data_hash}_{product}_{session_type}"
61+
current_time = time.time()
62+
63+
# Check if cached result exists and is still valid
3564
if cache_key in self._session_boundary_cache:
36-
cached_result = self._session_boundary_cache[cache_key]
37-
if isinstance(cached_result, tuple) and len(cached_result) == 2:
38-
return cached_result
65+
# Check TTL (backward compatible - if no timestamp, treat as valid)
66+
if cache_key in self._cache_timestamps:
67+
cache_age = current_time - self._cache_timestamps[cache_key]
68+
if cache_age < self.cache_ttl:
69+
cached_result = self._session_boundary_cache[cache_key]
70+
if isinstance(cached_result, tuple) and len(cached_result) == 2:
71+
return cached_result
72+
else:
73+
# Expired - remove from cache
74+
del self._session_boundary_cache[cache_key]
75+
del self._cache_timestamps[cache_key]
76+
else:
77+
# No timestamp entry (backward compatibility) - treat as valid
78+
cached_result = self._session_boundary_cache[cache_key]
79+
if isinstance(cached_result, tuple) and len(cached_result) == 2:
80+
# Add timestamp for future TTL checks
81+
self._cache_timestamps[cache_key] = current_time
82+
return cached_result
83+
84+
# Enforce cache size limit with LRU eviction
85+
if (
86+
len(self._session_boundary_cache) >= self.cache_max_size
87+
and self._cache_timestamps
88+
):
89+
oldest_key = min(
90+
self._cache_timestamps.keys(), key=lambda k: self._cache_timestamps[k]
91+
)
92+
del self._session_boundary_cache[oldest_key]
93+
del self._cache_timestamps[oldest_key]
3994

4095
# Calculate and cache boundaries (simplified implementation)
4196
boundaries: tuple[list[int], list[int]] = ([], [])
4297
self._session_boundary_cache[cache_key] = boundaries
98+
self._cache_timestamps[cache_key] = current_time
4399
return boundaries
44100

45101
def _use_lazy_evaluation(self, data: pl.DataFrame) -> pl.LazyFrame:
@@ -48,8 +104,8 @@ def _use_lazy_evaluation(self, data: pl.DataFrame) -> pl.LazyFrame:
48104

49105
def _optimize_filtering(self, data: pl.DataFrame) -> pl.DataFrame:
50106
"""Apply optimized filtering strategies for large datasets."""
51-
# For large datasets (>100k rows), use lazy evaluation
52-
if len(data) > 100_000:
107+
# Use configurable threshold for lazy evaluation
108+
if len(data) > self.lazy_eval_threshold:
53109
lazy_df = self._use_lazy_evaluation(data)
54110
# Would implement optimized lazy operations here
55111
return lazy_df.collect()
@@ -148,14 +204,33 @@ def _filter_rth_hours(
148204
self, data: pl.DataFrame, session_times: SessionTimes
149205
) -> pl.DataFrame:
150206
"""Filter data to RTH hours only."""
151-
# Convert to market timezone and filter by time
152-
# This is a simplified implementation for testing
207+
# Convert session times from ET to UTC for filtering
208+
# This properly handles DST transitions
209+
from datetime import UTC
153210

154-
# For ES: RTH is 9:30 AM - 4:00 PM ET
155-
# In UTC: 14:30 - 21:00 (standard time)
211+
import pytz
156212

157-
# Calculate UTC hours for RTH session times
158-
et_to_utc_offset = 5 # Standard time offset
213+
# Get market timezone
214+
et_tz = pytz.timezone("America/New_York")
215+
216+
# Get a sample timestamp from data to determine DST status
217+
if not data.is_empty():
218+
sample_ts = data["timestamp"][0]
219+
if sample_ts.tzinfo is None:
220+
# Assume UTC if no timezone
221+
sample_ts = sample_ts.replace(tzinfo=UTC)
222+
223+
# Convert to ET to check DST
224+
et_time = sample_ts.astimezone(et_tz)
225+
is_dst = bool(et_time.dst())
226+
227+
# Calculate proper UTC offset
228+
et_to_utc_offset = 4 if is_dst else 5 # EDT = UTC-4, EST = UTC-5
229+
else:
230+
# Default to standard time if no data
231+
et_to_utc_offset = 5
232+
233+
# Convert session times to UTC hours
159234
rth_start_hour = session_times.rth_start.hour + et_to_utc_offset
160235
rth_start_min = session_times.rth_start.minute
161236
rth_end_hour = session_times.rth_end.hour + et_to_utc_offset
@@ -188,6 +263,9 @@ def _filter_eth_hours(self, data: pl.DataFrame, product: str) -> pl.DataFrame:
188263
"""Filter data to ETH hours excluding maintenance breaks."""
189264
# ETH excludes maintenance breaks which vary by product
190265
# Most US futures: maintenance break 5:00 PM - 6:00 PM ET daily
266+
from datetime import UTC
267+
268+
import pytz
191269

192270
# Get maintenance break times for product
193271
maintenance_breaks = self._get_maintenance_breaks(product)
@@ -196,13 +274,25 @@ def _filter_eth_hours(self, data: pl.DataFrame, product: str) -> pl.DataFrame:
196274
# No maintenance breaks for this product - return all data
197275
return data
198276

277+
# Get market timezone
278+
et_tz = pytz.timezone("America/New_York")
279+
280+
# Determine DST status from sample timestamp
281+
if not data.is_empty():
282+
sample_ts = data["timestamp"][0]
283+
if sample_ts.tzinfo is None:
284+
sample_ts = sample_ts.replace(tzinfo=UTC)
285+
et_time = sample_ts.astimezone(et_tz)
286+
is_dst = bool(et_time.dst())
287+
et_to_utc_offset = 4 if is_dst else 5 # EDT = UTC-4, EST = UTC-5
288+
else:
289+
et_to_utc_offset = 5 # Default to standard time
290+
199291
# Start with all data and exclude maintenance periods
200292
filtered_conditions = []
201293

202294
for break_start, break_end in maintenance_breaks:
203295
# Convert ET maintenance times to UTC for filtering
204-
et_to_utc_offset = 5 # Standard time offset (need to handle DST properly)
205-
206296
break_start_hour = break_start.hour + et_to_utc_offset
207297
break_start_min = break_start.minute
208298
break_end_hour = break_end.hour + et_to_utc_offset
@@ -311,7 +401,7 @@ def is_in_session(
311401
if self._is_weekend_outside_eth(timestamp, market_time, session_type):
312402
return False
313403

314-
if self._is_maintenance_break(market_time.time()):
404+
if self._is_maintenance_break(market_time.time(), product):
315405
return False
316406

317407
# Apply session-specific logic
@@ -358,59 +448,33 @@ def _is_market_holiday(self, date: date) -> bool:
358448
date.month == 12 and date.day == 31
359449
)
360450

361-
def _is_dst_transition_date(self, date: date) -> bool:
362-
"""Check if the given date is a DST transition date."""
363-
# DST transitions in the US:
364-
# - Spring forward: Second Sunday in March
365-
# - Fall back: First Sunday in November
366-
367-
# Spring DST transition (second Sunday in March)
368-
if date.month == 3:
369-
# Find second Sunday in March
370-
first_day = date.replace(day=1)
371-
days_to_first_sunday = (6 - first_day.weekday()) % 7
372-
if days_to_first_sunday == 0:
373-
days_to_first_sunday = (
374-
7 # If March 1st is Sunday, first Sunday is March 8th
375-
)
376-
first_sunday = first_day.day + days_to_first_sunday
377-
second_sunday = first_sunday + 7
378-
return date.day == second_sunday
379-
380-
# Fall DST transition (first Sunday in November)
381-
elif date.month == 11:
382-
# Find first Sunday in November
383-
first_day = date.replace(day=1)
384-
days_to_first_sunday = (6 - first_day.weekday()) % 7
385-
first_sunday = first_day.day + days_to_first_sunday
386-
return date.day == first_sunday
387-
388-
return False
389-
390451
def _is_weekend_outside_eth(
391452
self,
392-
timestamp: datetime | str, # Keep for API compatibility
453+
timestamp: datetime | str,
393454
market_time: datetime,
394455
session_type: SessionType,
395456
) -> bool:
396457
"""Check if it's weekend outside of ETH trading hours."""
397458
if market_time.weekday() < 5: # Weekday
398459
return False
399460

400-
# Check for DST transition dates - these are special cases where RTH might be valid
401-
if self._is_dst_transition_date(market_time.date()):
402-
return False # Allow RTH during DST transitions for testing
403-
404461
# Weekend - check for Sunday evening ETH exception
405462
return not (
406463
market_time.weekday() == 6
407464
and market_time.hour >= 18
408465
and session_type == SessionType.ETH
409466
)
410467

411-
def _is_maintenance_break(self, current_time: time) -> bool:
412-
"""Check if current time is during maintenance break."""
413-
return time(17, 0) <= current_time < time(18, 0)
468+
def _is_maintenance_break(self, current_time: time, product: str = "ES") -> bool:
469+
"""Check if current time is during maintenance break for the given product."""
470+
maintenance_breaks = self._get_maintenance_breaks(product)
471+
472+
for break_start, break_end in maintenance_breaks:
473+
# Check if current time falls within any maintenance break
474+
if break_start <= current_time < break_end:
475+
return True
476+
477+
return False
414478

415479
def _check_session_hours(
416480
self, session_type: SessionType, session_times: SessionTimes, current_time: time

tests/unit/test_session_filter.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -685,16 +685,17 @@ def session_filter(self):
685685

686686
def test_daylight_saving_transitions(self, session_filter):
687687
"""Test session filtering during DST transitions."""
688-
# Spring forward: 2024-03-10 2:00 AM -> 3:00 AM ET
689-
# Fall back: 2024-11-03 2:00 AM -> 1:00 AM ET
688+
# Spring forward: 2024-03-10 2:00 AM -> 3:00 AM ET (Sunday)
689+
# Fall back: 2024-11-03 2:00 AM -> 1:00 AM ET (Sunday)
690+
# Test on the Monday after DST transitions when markets are open
690691

691-
# Test during spring DST transition
692-
spring_forward = datetime(2024, 3, 10, 15, 0, tzinfo=timezone.utc) # Should be RTH
693-
assert session_filter.is_in_session(spring_forward, SessionType.RTH, "ES") is True
692+
# Monday after spring DST transition (March 11, 2024)
693+
spring_monday = datetime(2024, 3, 11, 15, 0, tzinfo=timezone.utc) # 11:00 AM EDT - Should be RTH
694+
assert session_filter.is_in_session(spring_monday, SessionType.RTH, "ES") is True
694695

695-
# Test during fall DST transition
696-
fall_back = datetime(2024, 11, 3, 15, 0, tzinfo=timezone.utc) # Should be RTH
697-
assert session_filter.is_in_session(fall_back, SessionType.RTH, "ES") is True
696+
# Monday after fall DST transition (November 4, 2024)
697+
fall_monday = datetime(2024, 11, 4, 15, 0, tzinfo=timezone.utc) # 10:00 AM EST - Should be RTH
698+
assert session_filter.is_in_session(fall_monday, SessionType.RTH, "ES") is True
698699

699700
def test_leap_second_handling(self, session_filter):
700701
"""Test handling of leap seconds (rare edge case)."""

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)