Skip to content
6 changes: 5 additions & 1 deletion merlin/systems/dag/ops/feast.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
from typing import List

import numpy as np
Expand Down Expand Up @@ -82,7 +83,7 @@ def from_feature_view(

output_schema = Schema([])
for feature in feature_view.features:
feature_dtype, is_list, is_ragged = feast_2_numpy[feature.dtype]
feature_dtype, is_list, is_ragged = feast_2_numpy[feature.dtype.to_value_type()]

if is_list:
mh_features.append(feature.name)
Expand Down Expand Up @@ -165,6 +166,9 @@ def __init__(
self.output_prefix = output_prefix

self.store = FeatureStore(repo_path=repo_path)
# add feature view to the online store
self.store.materialize_incremental(datetime.now(), feature_views=[self.entity_view])

super().__init__()

def __getstate__(self):
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ testbook==0.4.2

# packages necessary to run tests and push PRs
tritonclient
feast==0.18.1
feast==0.31
xgboost==1.6.2
implicit==0.6.0

Expand Down
148 changes: 148 additions & 0 deletions tests/integration/feast/test_int_feast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from datetime import datetime

import numpy as np
import pytest

from merlin.core.dispatch import make_df
from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag import Ensemble
from merlin.systems.dag.ops.feast import QueryFeast # noqa
from merlin.table import Device, TensorTable # noqa

feast = pytest.importorskip("feast") # noqa


def test_feast_integration(tmpdir):
    """Run QueryFeast end to end against a freshly scaffolded Feast repository.

    The test scaffolds a repo with ``feast init``, swaps the generated example
    definitions for an ``item_features`` feature view backed by a parquet file
    written here, registers and materializes it with the ``feast`` CLI, and
    finally checks that an ``Ensemble`` containing ``QueryFeast`` returns the
    stored feature values for the requested item ids.

    Parameters
    ----------
    tmpdir
        pytest temporary-directory fixture; the Feast project is created inside it.

    Raises
    ------
    subprocess.CalledProcessError
        If any ``feast`` CLI invocation exits with a non-zero status.
    """
    # Local imports keep this block self-contained; both are standard library.
    import subprocess
    from datetime import timezone

    project_name = "test"
    # check=True makes a failing CLI call abort the test immediately instead of
    # surfacing later as an unrelated assertion or lookup error.
    subprocess.run(["feast", "init", project_name], cwd=str(tmpdir), check=True)
    feast_repo = os.path.join(tmpdir, f"{project_name}")
    feature_repo_path = os.path.join(feast_repo, "feature_repo/")

    # Drop the scaffolded example definitions and data so that only the
    # feature view written below gets registered by `feast apply`.
    for scaffold_file in ("example_repo.py", "data/driver_stats.parquet"):
        scaffold_path = os.path.join(feature_repo_path, scaffold_file)
        if os.path.exists(scaffold_path):
            os.remove(scaffold_path)

    df_path = os.path.join(feature_repo_path, "data/", "item_features.parquet")
    feat_file_path = os.path.join(feature_repo_path, "item_features.py")

    item_features = make_df(
        {
            "item_id": [1, 2, 3, 4, 5, 6, 7, 8, 9],
            "item_id_raw": [1, 2, 3, 4, 5, 6, 7, 8, 9],
            "item_category": [
                [1, 11],
                [2, 12],
                [3, 13],
                [4, 14],
                [5, 15],
                [6, 16],
                [7, 17],
                [8, 18],
                [9, 19],
            ],
            "item_brand": [1, 2, 3, 4, 5, 6, 7, 8, 9],
        }
    )
    item_features = TensorTable.from_df(item_features).to_df()
    # These two columns are the timestamp fields referenced by the FileSource
    # definition written to item_features.py below.
    item_features["datetime"] = datetime.now()
    item_features["datetime"] = item_features["datetime"].astype("datetime64[ns]")
    item_features["created"] = datetime.now()
    item_features["created"] = item_features["created"].astype("datetime64[ns]")

    item_features.to_parquet(df_path)

    # The generated module is written flush-left because it is itself Python source.
    with open(feat_file_path, "w", encoding="utf-8") as file:
        file.write(
            f"""
from datetime import timedelta
from feast import Entity, Field, FeatureView, ValueType
from feast.types import Int64, Array
from feast.infra.offline_stores.file_source import FileSource

item_features = FileSource(
path="{df_path}",
timestamp_field="datetime",
created_timestamp_column="created",
)

item = Entity(name="item_id", value_type=ValueType.INT64, join_keys=["item_id"],)

item_features_view = FeatureView(
name="item_features",
entities=[item],
ttl=timedelta(0),
schema=[
Field(name="item_category", dtype=Array(Int64)),
Field(name="item_brand", dtype=Int64),
Field(name="item_id_raw", dtype=Int64),
],
online=True,
source=item_features,
tags=dict(),
)
"""
        )

    # Register the definitions, then materialize everything up to "now" (UTC).
    subprocess.run(["feast", "apply"], cwd=feature_repo_path, check=True)
    materialize_end = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S")
    subprocess.run(
        ["feast", "materialize", "1995-01-01T01:01:01", materialize_end],
        cwd=feature_repo_path,
        check=True,
    )

    feature_store = feast.FeatureStore(feature_repo_path)

    # check the information is loaded and correctly querying
    feature_refs = [
        "item_features:item_id_raw",
        "item_features:item_category",
        "item_features:item_brand",
    ]
    feat_df = feature_store.get_historical_features(
        features=feature_refs,
        entity_df=make_df({"item_id": [1], "event_timestamp": [datetime.now()]}, device="cpu"),
    ).to_df()
    assert all(feat_df["item_id_raw"] == 1)

    # create and run ensemble with feast operator
    request_schema = Schema([ColumnSchema("item_id", dtype=np.int64)])
    graph = ["item_id"] >> QueryFeast.from_feature_view(
        store=feature_store,
        view="item_features",
        column="item_id",
        output_prefix="item",
        include_id=True,
    )
    ensemble = Ensemble(graph, request_schema)
    result = ensemble.transform(TensorTable.from_df(make_df({"item_id": [1, 2]})))

    columns = ["item_id_raw", "item_brand", "item_category"]
    result_df = result.to_df()
    # Item ids 1 and 2 were requested, i.e. the first two rows of the source frame.
    expected = item_features.iloc[0:2]
    if result.device == Device.GPU:
        for column in columns:
            if column == "item_category":
                # List columns are compared through their flattened values and
                # offsets rather than with a plain element-wise `==`.
                assert (
                    result_df[column]._column.leaves() == expected[column]._column.leaves()
                ).all()
                assert (
                    result_df[column]._column.offsets == expected[column]._column.offsets
                ).all()
            else:
                assert (result_df[column] == expected[column]).all()
    else:
        assert result_df[columns].equals(expected[columns])
51 changes: 34 additions & 17 deletions tests/unit/systems/ops/feast/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,42 @@ def test_feast_from_feature_view(tmpdir):
MagicMock(side_effect=QueryFeast),
) as qf_init:
input_source = feast.FileSource(
path=tmpdir,
path=str(tmpdir),
event_timestamp_column="datetime",
created_timestamp_column="created",
)
item_id = feast.Entity(
name="item_id",
value_type=feast.ValueType.INT32,
join_keys=["item_id"],
)
feature_view = feast.FeatureView(
name="item_features",
entities=["item_id"],
entities=[item_id],
ttl=timedelta(seconds=100),
features=[
feast.Feature(name="int_feature", dtype=feast.ValueType.INT32),
feast.Feature(name="float_feature", dtype=feast.ValueType.FLOAT),
feast.Feature(name="int_list_feature", dtype=feast.ValueType.INT32_LIST),
feast.Feature(name="float_list_feature", dtype=feast.ValueType.FLOAT_LIST),
schema=[
feast.Field(name="int_feature", dtype=feast.types.Int32),
feast.Field(name="float_feature", dtype=feast.types.Float32),
feast.Field(name="int_list_feature", dtype=feast.types.Array(feast.types.Int32)),
feast.Field(
name="float_list_feature", dtype=feast.types.Array(feast.types.Float32)
),
],
online=True,
input=input_source,
source=input_source,
tags={},
)
fs = feast.FeatureStore("repo_path")
fs.repo_path = "repo_path"
fs._registry = feast.feature_store.Registry(None, None)
fs.list_entities = MagicMock(
return_value=[feast.Entity(name="item_id", value_type=feast.ValueType.INT32)]
)
fs.get_feature_view = MagicMock(return_value=feature_view)
fs._registry.get_feature_view = MagicMock(return_value=feature_view)

fs._registry = feast.feature_store.Registry(None, None, "repo_path")
fs.list_entities = MagicMock(return_value=[item_id])

feast.FeatureStore._registry = MagicMock(return_value=fs._registry)
feast.FeatureStore.get_feature_view = MagicMock(return_value=feature_view)
feast.FeatureStore._registry.get_feature_view = MagicMock(return_value=feature_view)
feast.FeatureStore.materialize_incremental = MagicMock(return_value=None)
fs._registry._list_feature_views = MagicMock(return_value=feature_view)
fs._get_feature_view = MagicMock(return_value=feature_view)
expected_input_schema = Schema(
column_schemas=[ColumnSchema(name="item_id", dtype=np.int32)]
)
Expand Down Expand Up @@ -122,7 +131,7 @@ def test_feast_from_feature_view(tmpdir):

@pytest.mark.parametrize("is_ragged", [True, False])
@pytest.mark.parametrize("prefix", ["prefix", ""])
def test_feast_transform(prefix, is_ragged):
def test_feast_transform(tmpdir, prefix, is_ragged):
mocked_resp = OnlineResponse(
online_response_proto=ServingService_pb2.GetOnlineFeaturesResponse(
metadata=ServingService_pb2.GetOnlineFeaturesResponseMetadata(
Expand All @@ -134,10 +143,18 @@ def test_feast_transform(prefix, is_ragged):
ServingService_pb2.GetOnlineFeaturesResponse.FeatureVector(
values=[
Value_pb2.Value(int32_val=1),
]
),
ServingService_pb2.GetOnlineFeaturesResponse.FeatureVector(
values=[
Value_pb2.Value(float_val=1.0),
]
),
ServingService_pb2.GetOnlineFeaturesResponse.FeatureVector(
values=[
Value_pb2.Value(float_list_val=Value_pb2.FloatList(val=[1.0, 2.0, 3.0])),
]
)
),
],
)
)
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ sitepackages=true
setenv =
TF_GPU_ALLOCATOR=cuda_malloc_async
deps =
-rrequirements/test-gpu.txt
pytest
pytest-cov
commands =
Expand Down