Commit f9baf97

feat: Enable write node for compute engine (feast-dev#5287)
* enable write node
* fix linting
* remove debug
* rename module
* fix linting
* fix linting
* fix linting
* fix fv offline
* fix feature view proto
* fix write node
* fix write node

Signed-off-by: HaoXuAI <[email protected]>
1 parent a1388a5 commit f9baf97

File tree

14 files changed: +255 −77 lines

protos/feast/core/FeatureView.proto

Lines changed: 4 additions & 0 deletions

```diff
@@ -74,7 +74,11 @@ message FeatureViewSpec {
   DataSource stream_source = 9;

   // Whether these features should be served online or not
+  // This is also used to determine whether the features should be written to the online store
   bool online = 8;
+
+  // Whether these features should be written to the offline store
+  bool offline = 13;
 }

 message FeatureViewMeta {
```
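
The new `offline` flag mirrors the existing `online` flag at the feature-view level, and the SDK changes below thread it through the `FeatureView` constructor. A minimal usage sketch (the entity, source, and field names are hypothetical, not from this commit):

```python
from datetime import timedelta

from feast import Entity, FeatureView, Field, FileSource
from feast.types import Float32

# Hypothetical repo objects, for illustration only.
driver = Entity(name="driver", join_keys=["driver_id"])
source = FileSource(
    path="data/driver_stats.parquet", timestamp_field="event_timestamp"
)

fv = FeatureView(
    name="driver_stats",
    entities=[driver],
    ttl=timedelta(days=1),
    schema=[Field(name="conv_rate", dtype=Float32)],
    source=source,
    online=True,   # write materialized rows to the online store
    offline=True,  # new in this commit: also write them to the offline store
)
```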

sdk/python/feast/feature_view.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -236,6 +236,7 @@ def __copy__(self):
             schema=self.schema,
             tags=self.tags,
             online=self.online,
+            offline=self.offline,
         )

         # This is deliberately set outside of the FV initialization as we do not have the Entity objects.
@@ -258,6 +259,7 @@ def __eq__(self, other):
             sorted(self.entities) != sorted(other.entities)
             or self.ttl != other.ttl
             or self.online != other.online
+            or self.offline != other.offline
             or self.batch_source != other.batch_source
             or self.stream_source != other.stream_source
             or sorted(self.entity_columns) != sorted(other.entity_columns)
@@ -363,6 +365,7 @@ def to_proto(self) -> FeatureViewProto:
             owner=self.owner,
             ttl=(ttl_duration if ttl_duration is not None else None),
             online=self.online,
+            offline=self.offline,
             batch_source=batch_source_proto,
             stream_source=stream_source_proto,
         )
@@ -412,6 +415,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto):
             tags=dict(feature_view_proto.spec.tags),
             owner=feature_view_proto.spec.owner,
             online=feature_view_proto.spec.online,
+            offline=feature_view_proto.spec.offline,
             ttl=(
                 timedelta(days=0)
                 if feature_view_proto.spec.ttl.ToNanoseconds() == 0
```
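
Since `to_proto`/`from_proto` now carry the flag, a proto round trip should preserve it; a quick sanity-check sketch (reusing the hypothetical `fv` from the example above):

```python
proto = fv.to_proto()             # fv as constructed in the earlier sketch
assert proto.spec.online and proto.spec.offline

fv2 = FeatureView.from_proto(proto)
assert fv2.offline == fv.offline  # __eq__ now compares the new field too
```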
sdk/python/feast/infra/common/serde.py

Lines changed: 41 additions & 0 deletions

```diff
@@ -0,0 +1,41 @@
+from dataclasses import dataclass
+
+import dill
+
+from feast import FeatureView
+from feast.infra.passthrough_provider import PassthroughProvider
+from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
+
+
+@dataclass
+class SerializedArtifacts:
+    """Class to assist with serializing unpicklable artifacts to be passed to the compute engine."""
+
+    feature_view_proto: bytes
+    repo_config_byte: bytes
+
+    @classmethod
+    def serialize(cls, feature_view, repo_config):
+        # serialize the feature view to proto bytes
+        feature_view_proto = feature_view.to_proto().SerializeToString()
+
+        # serialize repo_config with dill; used later to instantiate the stores
+        repo_config_byte = dill.dumps(repo_config)
+
+        return SerializedArtifacts(
+            feature_view_proto=feature_view_proto, repo_config_byte=repo_config_byte
+        )
+
+    def unserialize(self):
+        # unserialize the feature view proto
+        proto = FeatureViewProto()
+        proto.ParseFromString(self.feature_view_proto)
+        feature_view = FeatureView.from_proto(proto)
+
+        # load the repo config
+        repo_config = dill.loads(self.repo_config_byte)
+
+        provider = PassthroughProvider(repo_config)
+        online_store = provider.online_store
+        offline_store = provider.offline_store
+        return feature_view, online_store, offline_store, repo_config
```
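
A hedged sketch of how the helper is meant to be used: serialize on the driver, reconstruct the view and stores on an executor. `fv` is any feature view, and the repo config is assumed to come from an existing `FeatureStore` (not shown in this commit):

```python
from feast import FeatureStore
from feast.infra.common.serde import SerializedArtifacts

store = FeatureStore(repo_path=".")  # assumes a feature repo in the cwd

# Driver side: pack the unpicklable bits into proto/dill payloads.
artifacts = SerializedArtifacts.serialize(feature_view=fv, repo_config=store.config)

# Executor side: rebuild everything needed to write features.
feature_view, online_store, offline_store, repo_config = artifacts.unserialize()
```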

sdk/python/feast/infra/compute_engines/local/feature_builder.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -91,6 +91,6 @@ def build_validation_node(self, input_node):
         return node

     def build_output_nodes(self, input_node):
-        node = LocalOutputNode("output")
+        node = LocalOutputNode("output", self.feature_view)
         node.add_input(input_node)
         self.nodes.append(node)
```

sdk/python/feast/infra/compute_engines/local/nodes.py

Lines changed: 36 additions & 3 deletions

```diff
@@ -1,8 +1,9 @@
 from datetime import datetime, timedelta
-from typing import Optional
+from typing import Optional, Union

 import pyarrow as pa

+from feast import BatchFeatureView, StreamFeatureView
 from feast.data_source import DataSource
 from feast.infra.compute_engines.dag.context import ExecutionContext
 from feast.infra.compute_engines.local.arrow_table_value import ArrowTableValue
@@ -11,6 +12,7 @@
 from feast.infra.offline_stores.offline_utils import (
     infer_event_timestamp_from_entity_df,
 )
+from feast.utils import _convert_arrow_to_proto

 ENTITY_TS_ALIAS = "__entity_event_timestamp"

@@ -207,11 +209,42 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:


 class LocalOutputNode(LocalNode):
-    def __init__(self, name: str):
+    def __init__(
+        self, name: str, feature_view: Union[BatchFeatureView, StreamFeatureView]
+    ):
         super().__init__(name)
+        self.feature_view = feature_view

     def execute(self, context: ExecutionContext) -> ArrowTableValue:
         input_table = self.get_single_table(context).data
         context.node_outputs[self.name] = input_table
-        # TODO: implement the logic to write to offline store
+
+        if self.feature_view.online:
+            online_store = context.online_store
+
+            join_key_to_value_type = {
+                entity.name: entity.dtype.to_value_type()
+                for entity in self.feature_view.entity_columns
+            }
+
+            rows_to_write = _convert_arrow_to_proto(
+                input_table, self.feature_view, join_key_to_value_type
+            )
+
+            online_store.online_write_batch(
+                config=context.repo_config,
+                table=self.feature_view,
+                data=rows_to_write,
+                progress=lambda x: None,
+            )
+
+        if self.feature_view.offline:
+            offline_store = context.offline_store
+            offline_store.offline_write_batch(
+                config=context.repo_config,
+                feature_view=self.feature_view,
+                table=input_table,
+                progress=lambda x: None,
+            )
+
         return input_table
```
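
For reference, `_convert_arrow_to_proto` produces the row-tuple payload that `online_write_batch` consumes; a types-only sketch (the alias name is ours, not Feast's):

```python
from datetime import datetime
from typing import Dict, List, Optional, Tuple

from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto

# One tuple per Arrow row: entity key, feature values, event ts, created ts.
OnlineWriteBatch = List[
    Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
]
```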

sdk/python/feast/infra/compute_engines/spark/feature_builder.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -5,7 +5,7 @@
 from feast.infra.common.materialization_job import MaterializationTask
 from feast.infra.common.retrieval_task import HistoricalRetrievalTask
 from feast.infra.compute_engines.feature_builder import FeatureBuilder
-from feast.infra.compute_engines.spark.node import (
+from feast.infra.compute_engines.spark.nodes import (
     SparkAggregationNode,
     SparkDedupNode,
     SparkFilterNode,
@@ -73,7 +73,8 @@ def build_transformation_node(self, input_node):
         return node

     def build_output_nodes(self, input_node):
-        node = SparkWriteNode("output", input_node, self.feature_view)
+        node = SparkWriteNode("output", self.feature_view)
+        node.add_input(input_node)
         self.nodes.append(node)
         return node

```

sdk/python/feast/infra/compute_engines/spark/node.py renamed to sdk/python/feast/infra/compute_engines/spark/nodes.py

Lines changed: 27 additions & 15 deletions

```diff
@@ -7,18 +7,19 @@
 from feast import BatchFeatureView, StreamFeatureView
 from feast.aggregation import Aggregation
 from feast.data_source import DataSource
+from feast.infra.common.serde import SerializedArtifacts
 from feast.infra.compute_engines.dag.context import ExecutionContext
 from feast.infra.compute_engines.dag.model import DAGFormat
 from feast.infra.compute_engines.dag.node import DAGNode
 from feast.infra.compute_engines.dag.value import DAGValue
-from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
-    _map_by_partition,
-    _SparkSerializedArtifacts,
-)
+from feast.infra.compute_engines.spark.utils import map_in_arrow
 from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
     SparkRetrievalJob,
     _get_entity_schema,
 )
+from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
+    SparkSource,
+)
 from feast.infra.offline_stores.offline_utils import (
     infer_event_timestamp_from_entity_df,
 )
@@ -273,30 +274,41 @@ class SparkWriteNode(DAGNode):
     def __init__(
         self,
         name: str,
-        input_node: DAGNode,
         feature_view: Union[BatchFeatureView, StreamFeatureView],
     ):
         super().__init__(name)
-        self.add_input(input_node)
         self.feature_view = feature_view

     def execute(self, context: ExecutionContext) -> DAGValue:
         spark_df: DataFrame = self.get_single_input_value(context).data
-        spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
+        serialized_artifacts = SerializedArtifacts.serialize(
             feature_view=self.feature_view, repo_config=context.repo_config
         )

-        # ✅ 1. Write to offline store (if enabled)
-        if self.feature_view.offline:
-            # TODO: Update _map_by_partition to be able to write to offline store
-            pass
-
-        # ✅ 2. Write to online store (if enabled)
+        # ✅ 1. Write to online store if online enabled
         if self.feature_view.online:
-            spark_df.mapInPandas(
-                lambda x: _map_by_partition(x, spark_serialized_artifacts), "status int"
+            spark_df.mapInArrow(
+                lambda x: map_in_arrow(x, serialized_artifacts, mode="online"),
+                spark_df.schema,
             ).count()

+        # ✅ 2. Write to offline store if offline enabled
+        if self.feature_view.offline:
+            if not isinstance(self.feature_view.batch_source, SparkSource):
+                spark_df.mapInArrow(
+                    lambda x: map_in_arrow(x, serialized_artifacts, mode="offline"),
+                    spark_df.schema,
+                ).count()
+            # Directly write spark df to spark offline store without using mapInArrow
+            else:
+                dest_path = self.feature_view.batch_source.path
+                file_format = self.feature_view.batch_source.file_format
+                if not dest_path or not file_format:
+                    raise ValueError(
+                        "Destination path and file format must be specified for SparkSource."
+                    )
+                spark_df.write.format(file_format).mode("append").save(dest_path)
+
         return DAGValue(
             data=spark_df,
             format=DAGFormat.SPARK,
```
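
The write path relies on `mapInArrow` being lazy: the trailing `.count()` is the action that actually pushes batches through the write function on each executor. A self-contained sketch of that pattern (Spark 3.3+; toy data and a placeholder write, not the commit's code):

```python
import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.createDataFrame([(1, 0.9), (2, 0.7)], ["driver_id", "conv_rate"])

def write_batches(batches):
    for batch in batches:
        table = pa.Table.from_batches([batch])
        # side effect: write `table` to a store here
        yield batch  # pass rows through unchanged, matching df.schema

# Lazy until an action runs; count() forces the writes on the executors.
df.mapInArrow(write_batches, df.schema).count()
```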

sdk/python/feast/infra/compute_engines/spark/utils.py

Lines changed: 48 additions & 1 deletion

```diff
@@ -1,8 +1,12 @@
-from typing import Dict, Optional
+from typing import Dict, Iterable, Literal, Optional

+import pyarrow as pa
 from pyspark import SparkConf
 from pyspark.sql import SparkSession

+from feast.infra.common.serde import SerializedArtifacts
+from feast.utils import _convert_arrow_to_proto
+

 def get_or_create_new_spark_session(
     spark_config: Optional[Dict[str, str]] = None,
@@ -16,4 +20,47 @@ def get_or_create_new_spark_session(
     )

     spark_session = spark_builder.getOrCreate()
+    spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
     return spark_session
+
+
+def map_in_arrow(
+    iterator: Iterable[pa.RecordBatch],
+    serialized_artifacts: "SerializedArtifacts",
+    mode: Literal["online", "offline"] = "online",
+):
+    for batch in iterator:
+        table = pa.Table.from_batches([batch])
+
+        (
+            feature_view,
+            online_store,
+            offline_store,
+            repo_config,
+        ) = serialized_artifacts.unserialize()
+
+        if mode == "online":
+            join_key_to_value_type = {
+                entity.name: entity.dtype.to_value_type()
+                for entity in feature_view.entity_columns
+            }
+
+            rows_to_write = _convert_arrow_to_proto(
+                table, feature_view, join_key_to_value_type
+            )
+
+            online_store.online_write_batch(
+                config=repo_config,
+                table=feature_view,
+                data=rows_to_write,
+                progress=lambda x: None,
+            )
+        if mode == "offline":
+            offline_store.offline_write_batch(
+                config=repo_config,
+                feature_view=feature_view,
+                table=table,
+                progress=lambda x: None,
+            )
+
+        yield batch
```
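
Wired up, the helper mirrors what `SparkWriteNode` does above; `spark_df` and `artifacts` are assumed to already exist (see the serde sketch earlier):

```python
# Online write pass; an offline pass works the same way with mode="offline".
spark_df.mapInArrow(
    lambda batches: map_in_arrow(batches, artifacts, mode="online"),
    spark_df.schema,  # pass-through: map_in_arrow yields each batch unchanged
).count()             # mapInArrow is lazy; count() forces the writes to run
```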
