Skip to content

Commit cc2a46d

Browse files
committed
feat: Batch Embedding at scale for RAG with Ray
Signed-off-by: ntkathole <[email protected]>
1 parent 2390d2e commit cc2a46d

File tree

24 files changed

+1310
-65
lines changed

24 files changed

+1310
-65
lines changed

protos/feast/core/Transformation.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ message UserDefinedFunctionV2 {
1616
// The string representation of the udf
1717
string body_text = 3;
1818

19-
// The transformation mode (e.g., "python", "pandas", "spark", "sql")
19+
// The transformation mode (e.g., "python", "pandas", "ray", "spark", "sql")
2020
string mode = 4;
2121
}
2222

sdk/python/feast/batch_feature_view.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class BatchFeatureView(FeatureView):
5757
"""
5858

5959
name: str
60-
mode: Union[TransformationMode, str]
6160
entities: List[str]
6261
ttl: Optional[timedelta]
6362
source: DataSource
@@ -146,7 +145,8 @@ def get_feature_transformation(self) -> Optional[Transformation]:
146145
TransformationMode.PANDAS,
147146
TransformationMode.PYTHON,
148147
TransformationMode.SQL,
149-
) or self.mode in ("pandas", "python", "sql"):
148+
TransformationMode.RAY,
149+
) or self.mode in ("pandas", "python", "sql", "ray"):
150150
return Transformation(
151151
mode=self.mode, udf=self.udf, udf_string=self.udf_string or ""
152152
)

sdk/python/feast/cli/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def materialize_incremental_command(ctx: click.Context, end_ts: str, views: List
411411
"couchbase",
412412
"milvus",
413413
"ray",
414+
"ray_rag",
414415
],
415416
case_sensitive=False,
416417
),

sdk/python/feast/feature_view.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ class FeatureView(BaseFeatureView):
8787
tags: A dictionary of key-value pairs to store arbitrary metadata.
8888
owner: The owner of the feature view, typically the email of the primary
8989
maintainer.
90+
mode: The transformation mode for feature transformations. Only meaningful when
91+
transformations are applied. Choose from TransformationMode enum values
92+
(e.g., PYTHON, PANDAS, RAY, SQL, SPARK, SUBSTRAIT).
9093
"""
9194

9295
name: str
@@ -143,7 +146,8 @@ def __init__(
143146
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
144147
owner (optional): The owner of the feature view, typically the email of the
145148
primary maintainer.
146-
mode (optional): The transformation mode to use (e.g., python, pandas, spark, sql).
149+
mode (optional): The transformation mode for feature transformations. Only meaningful
150+
when transformations are applied. Choose from TransformationMode enum values.
147151
148152
Raises:
149153
ValueError: A field mapping conflicts with an Entity or a Feature.
@@ -152,6 +156,7 @@ def __init__(
152156
self.entities = [e.name for e in entities] if entities else [DUMMY_ENTITY_NAME]
153157
self.ttl = ttl
154158
schema = schema or []
159+
self.mode = mode
155160

156161
# Normalize source
157162
self.stream_source = None

sdk/python/feast/infra/compute_engines/feature_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ def get_column_info(
154154
)
155155
field_mapping = self.get_field_mapping(self.task.feature_view)
156156

157+
# For feature views with transformations that need access to all source columns,
158+
# we need to read ALL source columns, not just the output feature columns.
159+
# This is specifically for transformations that create new columns or need raw data.
160+
mode = getattr(getattr(view, "feature_transformation", None), "mode", None)
161+
if mode == "ray" or getattr(mode, "value", None) == "ray":
162+
# Signal to read all columns by passing empty list for feature_cols
163+
# The transformation will produce the output columns defined in the schema
164+
feature_cols = []
165+
157166
return ColumnInfo(
158167
join_keys=join_keys,
159168
feature_cols=feature_cols,

sdk/python/feast/infra/compute_engines/ray/feature_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ def build_output_nodes(self, view, final_node):
161161
name="output",
162162
feature_view=view,
163163
inputs=[final_node],
164+
config=self.config,
164165
)
165166

166167
self.nodes.append(node)
@@ -275,6 +276,7 @@ def _build_materialization_plan(self) -> ExecutionPlan:
275276
name=f"{view.name}:write",
276277
feature_view=view,
277278
inputs=[processing_node],
279+
config=self.config,
278280
)
279281

280282
view_to_write_node[view.name] = write_node

sdk/python/feast/infra/compute_engines/ray/nodes.py

Lines changed: 60 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ def join_with_aggregated_features(batch: pd.DataFrame) -> pd.DataFrame:
173173
return result
174174

175175
joined_dataset = entity_dataset.map_batches(
176-
join_with_aggregated_features, batch_format="pandas"
176+
join_with_aggregated_features,
177+
batch_format="pandas",
178+
concurrency=self.config.max_workers or 12,
177179
)
178180
else:
179181
if feature_size <= self.config.broadcast_join_threshold_mb * 1024 * 1024:
@@ -274,8 +276,8 @@ def apply_filters(batch: pd.DataFrame) -> pd.DataFrame:
274276
else:
275277
# Use current time for TTL calculation (real-time retrieval)
276278
# Check if timestamp column is timezone-aware
277-
if pd.api.types.is_datetime64tz_dtype(
278-
filtered_batch[timestamp_col]
279+
if isinstance(
280+
filtered_batch[timestamp_col].dtype, pd.DatetimeTZDtype
279281
):
280282
# Use timezone-aware current time
281283
current_time = datetime.now(timezone.utc)
@@ -517,31 +519,59 @@ def execute(self, context: ExecutionContext) -> DAGValue:
517519
input_value.assert_format(DAGFormat.RAY)
518520
dataset: Dataset = input_value.data
519521

520-
transformation_serialized = None
521-
if hasattr(self.transformation, "udf") and callable(self.transformation.udf):
522-
transformation_serialized = dill.dumps(self.transformation.udf)
523-
elif callable(self.transformation):
524-
transformation_serialized = dill.dumps(self.transformation)
522+
# Check transformation mode
523+
from feast.transformation.mode import TransformationMode
525524

526-
@safe_batch_processor
527-
def apply_transformation_with_serialized_udf(
528-
batch: pd.DataFrame,
529-
) -> pd.DataFrame:
530-
"""Apply the transformation using pre-serialized UDF."""
531-
if transformation_serialized:
532-
transformation_func = dill.loads(transformation_serialized)
533-
transformed_batch = transformation_func(batch)
525+
transformation_mode = getattr(
526+
self.transformation, "mode", TransformationMode.PYTHON
527+
)
528+
is_ray_native = transformation_mode in (TransformationMode.RAY, "ray")
529+
if is_ray_native:
530+
transformation_func = None
531+
if hasattr(self.transformation, "udf") and callable(
532+
self.transformation.udf
533+
):
534+
transformation_func = self.transformation.udf
535+
elif callable(self.transformation):
536+
transformation_func = self.transformation
537+
538+
if transformation_func:
539+
transformed_dataset = transformation_func(dataset)
534540
else:
535541
logger.warning(
536-
"No serialized transformation available, returning original batch"
542+
"No transformation function available in RAY mode, returning original dataset"
537543
)
538-
transformed_batch = batch
544+
transformed_dataset = dataset
545+
else:
546+
transformation_serialized = None
547+
if hasattr(self.transformation, "udf") and callable(
548+
self.transformation.udf
549+
):
550+
transformation_serialized = dill.dumps(self.transformation.udf)
551+
elif callable(self.transformation):
552+
transformation_serialized = dill.dumps(self.transformation)
539553

540-
return transformed_batch
554+
@safe_batch_processor
555+
def apply_transformation_with_serialized_udf(
556+
batch: pd.DataFrame,
557+
) -> pd.DataFrame:
558+
"""Apply the transformation using pre-serialized UDF."""
559+
if transformation_serialized:
560+
transformation_func = dill.loads(transformation_serialized)
561+
transformed_batch = transformation_func(batch)
562+
else:
563+
logger.warning(
564+
"No serialized transformation available, returning original batch"
565+
)
566+
transformed_batch = batch
541567

542-
transformed_dataset = dataset.map_batches(
543-
apply_transformation_with_serialized_udf, batch_format="pandas"
544-
)
568+
return transformed_batch
569+
570+
transformed_dataset = dataset.map_batches(
571+
apply_transformation_with_serialized_udf,
572+
batch_format="pandas",
573+
concurrency=self.config.max_workers or 12,
574+
)
545575

546576
return DAGValue(
547577
data=transformed_dataset,
@@ -598,7 +628,9 @@ def apply_transformation(batch: pd.DataFrame) -> pd.DataFrame:
598628
return transformation_func(batch)
599629

600630
transformed_dataset = parent_value.data.map_batches(
601-
apply_transformation
631+
apply_transformation,
632+
batch_format="pandas",
633+
concurrency=self.config.max_workers or 12,
602634
)
603635
return DAGValue(
604636
data=transformed_dataset,
@@ -630,9 +662,11 @@ def __init__(
630662
name: str,
631663
feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
632664
inputs=None,
665+
config: Optional[RayComputeEngineConfig] = None,
633666
):
634667
super().__init__(name, inputs=inputs)
635668
self.feature_view = feature_view
669+
self.config = config
636670

637671
def execute(self, context: ExecutionContext) -> DAGValue:
638672
"""Execute the write operation."""
@@ -676,7 +710,9 @@ def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame:
676710
return batch
677711

678712
written_dataset = dataset.map_batches(
679-
write_batch_with_serialized_artifacts, batch_format="pandas"
713+
write_batch_with_serialized_artifacts,
714+
batch_format="pandas",
715+
concurrency=self.config.max_workers if self.config else 12,
680716
)
681717
written_dataset = written_dataset.materialize()
682718

0 commit comments

Comments (0)