28 changes: 15 additions & 13 deletions sdk/python/feast/infra/contrib/spark_kafka_processor.py
@@ -1,5 +1,5 @@
 from types import MethodType
-from typing import List, Optional, no_type_check
+from typing import Dict, List, Optional, no_type_check
 
 import pandas as pd
 from pyspark.sql import DataFrame, SparkSession
@@ -29,6 +29,7 @@ class SparkKafkaProcessor(StreamProcessor):
     format: str
     preprocess_fn: Optional[MethodType]
     join_keys: List[str]
+    stream_source_options: Optional[Dict[str, str]]
 
     def __init__(
         self,
@@ -37,6 +38,7 @@ def __init__(
         sfv: StreamFeatureView,
         config: ProcessorConfig,
         preprocess_fn: Optional[MethodType] = None,
+        stream_source_options: Optional[Dict[str, str]] = None,
     ):
         if not isinstance(sfv.stream_source, KafkaSource):
            raise ValueError("data source is not kafka source")
@@ -59,6 +61,7 @@ def __init__(
             raise ValueError("config is not spark processor config")
         self.spark = config.spark_session
         self.preprocess_fn = preprocess_fn
+        self.stream_source_options = stream_source_options
         self.processing_time = config.processing_time
         self.query_timeout = config.query_timeout
         self.join_keys = [fs.get_entity(entity).join_key for entity in sfv.entities]
@@ -80,19 +83,23 @@ def ingest_stream_feature_view(
     @no_type_check
     def _ingest_stream_data(self) -> StreamTable:
         """Only supports json and avro formats currently."""
+        kafka_options: Dict[str, str] = {
+            "kafka.bootstrap.servers": self.data_source.kafka_options.kafka_bootstrap_servers,
+            "subscribe": self.data_source.kafka_options.topic,
+            "startingOffsets": "latest",
+        }
+        if self.stream_source_options:
+            # Merge in user-provided options so they override the defaults above
+            kafka_options.update(self.stream_source_options)
+
         if self.format == "json":
             if not isinstance(
                 self.data_source.kafka_options.message_format, JsonFormat
             ):
                 raise ValueError("kafka source message format is not jsonformat")
             stream_df = (
                 self.spark.readStream.format("kafka")
-                .option(
-                    "kafka.bootstrap.servers",
-                    self.data_source.kafka_options.kafka_bootstrap_servers,
-                )
-                .option("subscribe", self.data_source.kafka_options.topic)
-                .option("startingOffsets", "latest")  # Query start
+                .options(**kafka_options)
                 .load()
                 .selectExpr("CAST(value AS STRING)")
                 .select(
@@ -110,12 +117,7 @@ def _ingest_stream_data(self) -> StreamTable:
                 raise ValueError("kafka source message format is not avro format")
             stream_df = (
                 self.spark.readStream.format("kafka")
-                .option(
-                    "kafka.bootstrap.servers",
-                    self.data_source.kafka_options.kafka_bootstrap_servers,
-                )
-                .option("subscribe", self.data_source.kafka_options.topic)
-                .option("startingOffsets", "latest")  # Query start
+                .options(**kafka_options)
                 .load()
                 .selectExpr("CAST(value AS STRING)")
                 .select(
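For context, a minimal usage sketch of the new parameter, not part of the diff itself. It assumes `store` is an existing `FeatureStore`, `driver_stats_sfv` is a `StreamFeatureView` backed by a `KafkaSource`, and `spark` is an active `SparkSession`; the `SparkProcessorConfig` fields shown follow the existing processor config in this module. Keys passed in `stream_source_options` are merged over the defaults built in `_ingest_stream_data` via `kafka_options.update()`, so any valid Spark Kafka reader option can be overridden or added:

from feast.infra.contrib.spark_kafka_processor import (
    SparkKafkaProcessor,
    SparkProcessorConfig,
)

processor = SparkKafkaProcessor(
    fs=store,
    sfv=driver_stats_sfv,
    config=SparkProcessorConfig(
        mode="spark",
        source="kafka",
        spark_session=spark,
        processing_time="30 seconds",
        query_timeout=15,
    ),
    preprocess_fn=None,
    # Overrides the default "startingOffsets": "latest" so the stream replays
    # from the earliest available offset; both are standard Spark Kafka options.
    stream_source_options={
        "startingOffsets": "earliest",
        "failOnDataLoss": "false",
    },
)
processor.ingest_stream_feature_view()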