Skip to content

Commit 408f40c

Browse files
committed
Added test cases for datetime range based feature retrieval in Ray
Signed-off-by: Aniket Paluskar <[email protected]>
1 parent 66dcf65 commit 408f40c

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

sdk/python/feast/infra/offline_stores/contrib/ray_offline_store/tests/test_ray_integration.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from datetime import timedelta
2+
13
import pandas as pd
24
import pytest
35

@@ -144,3 +146,126 @@ def test_ray_offline_store_persist(environment, universal_data_sources):
144146
import os
145147

146148
assert os.path.exists(saved_path)
149+
150+
151+
@pytest.mark.integration
@pytest.mark.universal_offline_stores
def test_ray_offline_store_non_entity_mode_basic(environment, universal_data_sources):
    """Verify non-entity-mode retrieval (entity_df=None) over a datetime range.

    With no entity dataframe supplied, get_historical_features should return
    every feature row whose event_timestamp falls inside the requested
    [start_date, end_date] window.
    """
    store = environment.feature_store
    (entities, datasets, data_sources) = universal_data_sources
    feature_views = construct_universal_feature_views(data_sources)

    store.apply([driver(), feature_views.driver])

    # Query across the environment's entire data window.
    window_start = environment.start_date
    window_end = environment.end_date

    # Non-entity mode: entity_df=None, range bounded by start_date/end_date.
    result_df = store.get_historical_features(
        entity_df=None,
        features=[
            "driver_stats:conv_rate",
            "driver_stats:acc_rate",
            "driver_stats:avg_daily_trips",
        ],
        full_feature_names=False,
        start_date=window_start,
        end_date=window_end,
    ).to_df()

    # Rows must come back, carrying both the feature and the key columns.
    assert len(result_df) > 0, "Non-entity mode should return data"
    for column in (
        "conv_rate",
        "acc_rate",
        "avg_daily_trips",
        "event_timestamp",
        "driver_id",
    ):
        assert column in result_df.columns

    # Every returned timestamp must lie inside the requested range.
    result_df["event_timestamp"] = pd.to_datetime(
        result_df["event_timestamp"], utc=True
    )
    assert (result_df["event_timestamp"] >= window_start).all()
    assert (result_df["event_timestamp"] <= window_end).all()
202+
203+
204+
@pytest.mark.integration
@pytest.mark.universal_offline_stores
def test_ray_offline_store_non_entity_mode_preserves_multiple_timestamps(
    environment, universal_data_sources
):
    """Regression test: non-entity mode must keep every (entity, timestamp) pair.

    Guards the fix that preserves distinct (entity_key, event_timestamp)
    combinations rather than collapsing to distinct entity keys only — needed
    for correct point-in-time joins when one entity has several transactions.
    """
    store = environment.feature_store
    (entities, datasets, data_sources) = universal_data_sources
    feature_views = construct_universal_feature_views(data_sources)

    store.apply([driver(), feature_views.driver])

    now = _utc_now()
    # Three distinct event times, millisecond-rounded, spanning two hours.
    timestamps = [
        pd.Timestamp(now - timedelta(hours=2)).round("ms"),
        pd.Timestamp(now - timedelta(hours=1)).round("ms"),
        pd.Timestamp(now).round("ms"),
    ]

    # One entity (driver_id=9001) with a row per timestamp.
    df_to_write = pd.DataFrame(
        {
            "event_timestamp": timestamps,
            "driver_id": [9001, 9001, 9001],  # Same entity, different timestamps
            "conv_rate": [0.1, 0.2, 0.3],
            "acc_rate": [0.9, 0.8, 0.7],
            "avg_daily_trips": [10, 20, 30],
            "created": timestamps,
        }
    )

    store.write_to_offline_store(
        feature_views.driver.name, df_to_write, allow_registry_cache=False
    )

    # Query without entity_df — all 3 rows for driver_id=9001 should survive.
    result_df = store.get_historical_features(
        entity_df=None,
        features=[
            "driver_stats:conv_rate",
            "driver_stats:acc_rate",
        ],
        full_feature_names=False,
        start_date=timestamps[0] - timedelta(minutes=1),
        end_date=timestamps[-1] + timedelta(minutes=1),
    ).to_df()

    # Restrict to the entity written above.
    result_df = result_df[result_df["driver_id"] == 9001]

    # All three timestamped rows must be present, not a deduplicated single row.
    assert len(result_df) == 3, (
        f"Expected 3 rows for driver_id=9001 (one per timestamp), got {len(result_df)}"
    )

    # Feature values must line up with their timestamps in order.
    result_df = result_df.sort_values("event_timestamp").reset_index(drop=True)
    assert list(result_df["conv_rate"]) == [0.1, 0.2, 0.3]
    assert list(result_df["acc_rate"]) == [0.9, 0.8, 0.7]

0 commit comments

Comments
 (0)