From b1e9927ccc2f337e8360fa04044a91ce3bfc55c2 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 14 Sep 2025 16:06:40 -0700 Subject: [PATCH 1/3] phase3 Signed-off-by: HaoXuAI --- sdk/python/feast/dataframe.py | 20 ++- .../{local => }/backends/base.py | 30 ++++ .../compute_engines/backends/dask_backend.py | 115 ++++++++++++ .../infra/compute_engines/backends/factory.py | 92 ++++++++++ .../{local => }/backends/pandas_backend.py | 18 +- .../backends/polars_backend.py | 90 ++++++++++ .../compute_engines/backends/spark_backend.py | 163 ++++++++++++++++++ .../local/backends/__init__.py | 0 .../compute_engines/local/backends/factory.py | 53 ------ .../local/backends/polars_backend.py | 47 ----- 10 files changed, 526 insertions(+), 102 deletions(-) rename sdk/python/feast/infra/compute_engines/{local => }/backends/base.py (77%) create mode 100644 sdk/python/feast/infra/compute_engines/backends/dask_backend.py create mode 100644 sdk/python/feast/infra/compute_engines/backends/factory.py rename sdk/python/feast/infra/compute_engines/{local => }/backends/pandas_backend.py (61%) create mode 100644 sdk/python/feast/infra/compute_engines/backends/polars_backend.py create mode 100644 sdk/python/feast/infra/compute_engines/backends/spark_backend.py delete mode 100644 sdk/python/feast/infra/compute_engines/local/backends/__init__.py delete mode 100644 sdk/python/feast/infra/compute_engines/local/backends/factory.py delete mode 100644 sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py diff --git a/sdk/python/feast/dataframe.py b/sdk/python/feast/dataframe.py index 0a54a11c232..dc36ca3127c 100644 --- a/sdk/python/feast/dataframe.py +++ b/sdk/python/feast/dataframe.py @@ -1,11 +1,29 @@ """FeastDataFrame: A lightweight container for DataFrame-like objects in Feast.""" from enum import Enum -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import pandas as pd import pyarrow as pa +if TYPE_CHECKING: + import dask.dataframe + import polars + import pyspark.sql + import ray.data + + +# Type alias for entity DataFrame - supports various DataFrame types +DataFrameType = Union[ + "pd.DataFrame", + "pyspark.sql.DataFrame", + "dask.dataframe.DataFrame", + "polars.DataFrame", + "ray.data.Dataset", + "pa.Table", + str, +] + class DataFrameEngine(str, Enum): """Supported DataFrame engines.""" diff --git a/sdk/python/feast/infra/compute_engines/local/backends/base.py b/sdk/python/feast/infra/compute_engines/backends/base.py similarity index 77% rename from sdk/python/feast/infra/compute_engines/local/backends/base.py rename to sdk/python/feast/infra/compute_engines/backends/base.py index 3c8d25abe00..38d926353d2 100644 --- a/sdk/python/feast/infra/compute_engines/local/backends/base.py +++ b/sdk/python/feast/infra/compute_engines/backends/base.py @@ -1,5 +1,9 @@ from abc import ABC, abstractmethod from datetime import timedelta +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np class DataFrameBackend(ABC): @@ -18,6 +22,8 @@ class DataFrameBackend(ABC): Expected implementations include: - PandasBackend - PolarsBackend + - SparkBackend + - DaskBackend - DuckDBBackend (future) Methods @@ -77,3 +83,27 @@ def drop_duplicates(self, df, keys, sort_by, ascending: bool = False): @abstractmethod def rename_columns(self, df, columns: dict[str, str]): ... + + @abstractmethod + def get_schema(self, df) -> dict[str, "np.dtype"]: + """ + Get the schema of the DataFrame as a dictionary of column names to numpy data types. 
+ + Returns: + Dictionary mapping column names to their numpy dtype objects + """ + ... + + @abstractmethod + def get_timestamp_range(self, df, timestamp_column: str) -> tuple: + """ + Get the min and max values of a timestamp column. + + Args: + df: The DataFrame + timestamp_column: Name of the timestamp column + + Returns: + Tuple of (min_timestamp, max_timestamp) as datetime objects + """ + ... diff --git a/sdk/python/feast/infra/compute_engines/backends/dask_backend.py b/sdk/python/feast/infra/compute_engines/backends/dask_backend.py new file mode 100644 index 00000000000..b65bed06c1b --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/backends/dask_backend.py @@ -0,0 +1,115 @@ +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +import dask.dataframe as dd +import pyarrow as pa + +from feast.infra.compute_engines.backends.base import DataFrameBackend + +if TYPE_CHECKING: + import numpy as np + + +class DaskBackend(DataFrameBackend): + """Dask DataFrame backend implementation.""" + + def __init__(self): + if dd is None: + raise ImportError( + "Dask is not installed. Please install it to use Dask backend." + ) + + def columns(self, df: dd.DataFrame) -> List[str]: + """Get column names from Dask DataFrame.""" + return df.columns.tolist() + + def from_arrow(self, table: pa.Table) -> dd.DataFrame: + """Convert Arrow table to Dask DataFrame.""" + pandas_df = table.to_pandas() + return dd.from_pandas(pandas_df, npartitions=1) + + def to_arrow(self, df: dd.DataFrame) -> pa.Table: + """Convert Dask DataFrame to Arrow table.""" + pandas_df = df.compute() + return pa.Table.from_pandas(pandas_df) + + def join( + self, left: dd.DataFrame, right: dd.DataFrame, on: List[str], how: str + ) -> dd.DataFrame: + """Join two Dask DataFrames.""" + return left.merge(right, on=on, how=how) + + def groupby_agg( + self, + df: dd.DataFrame, + group_keys: List[str], + agg_ops: Dict[str, Tuple[str, str]], + ) -> dd.DataFrame: + """Group and aggregate Dask DataFrame.""" + + # Convert agg_ops to pandas-style aggregation + agg_dict = {col_name: func for alias, (func, col_name) in agg_ops.items()} + + result = df.groupby(group_keys).agg(agg_dict).reset_index() + + # Rename columns to match expected aliases + rename_mapping = {} + for alias, (func, col_name) in agg_ops.items(): + if func in ["sum", "count", "mean", "min", "max"]: + old_name = f"{col_name}" + if old_name in result.columns: + rename_mapping[old_name] = alias + + if rename_mapping: + result = result.rename(columns=rename_mapping) + + return result + + def filter(self, df: dd.DataFrame, expr: str) -> dd.DataFrame: + """Apply filter expression to Dask DataFrame.""" + return df.query(expr) + + def to_timedelta_value(self, delta: timedelta) -> Any: + """Convert timedelta to Dask-compatible timedelta.""" + import pandas as pd + + return pd.to_timedelta(delta) + + def drop_duplicates( + self, + df: dd.DataFrame, + keys: List[str], + sort_by: List[str], + ascending: bool = False, + ) -> dd.DataFrame: + """Deduplicate Dask DataFrame by keys, sorted by sort_by columns.""" + return df.drop_duplicates(subset=keys).reset_index(drop=True) + + def rename_columns(self, df: dd.DataFrame, columns: Dict[str, str]) -> dd.DataFrame: + """Rename columns in Dask DataFrame.""" + return df.rename(columns=columns) + + def get_schema(self, df: dd.DataFrame) -> Dict[str, "np.dtype"]: + """Get Dask DataFrame schema as column name to numpy dtype mapping.""" + # Dask dtypes are pandas-compatible numpy dtypes + return {col: dtype for col, 
dtype in df.dtypes.items()} + + def get_timestamp_range(self, df: dd.DataFrame, timestamp_column: str) -> tuple: + """Get min/max of a timestamp column in Dask DataFrame.""" + import pandas as pd + + col = df[timestamp_column] + # Ensure it's datetime type + if not pd.api.types.is_datetime64_any_dtype(col.dtype): + col = dd.to_datetime(col, utc=True) + + min_val = col.min().compute() + max_val = col.max().compute() + + # Convert to datetime objects + if hasattr(min_val, "to_pydatetime"): + min_val = min_val.to_pydatetime() + if hasattr(max_val, "to_pydatetime"): + max_val = max_val.to_pydatetime() + + return (min_val, max_val) diff --git a/sdk/python/feast/infra/compute_engines/backends/factory.py b/sdk/python/feast/infra/compute_engines/backends/factory.py new file mode 100644 index 00000000000..05d4b56975d --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/backends/factory.py @@ -0,0 +1,92 @@ +from typing import Optional + +import pandas as pd +import pyarrow + +from feast.infra.compute_engines.backends.base import DataFrameBackend +from feast.infra.compute_engines.backends.pandas_backend import PandasBackend + + +class BackendFactory: + """ + Factory class for constructing DataFrameBackend implementations based on backend name + or runtime entity_df type. + """ + + @staticmethod + def from_name(name: str) -> DataFrameBackend: + if name == "pandas": + return PandasBackend() + if name == "polars": + return BackendFactory._get_polars_backend() + if name == "spark": + return BackendFactory._get_spark_backend() + if name == "dask": + return BackendFactory._get_dask_backend() + raise ValueError(f"Unsupported backend name: {name}") + + @staticmethod + def infer_from_entity_df(entity_df) -> Optional[DataFrameBackend]: + if ( + not entity_df + or isinstance(entity_df, pyarrow.Table) + or isinstance(entity_df, pd.DataFrame) + ): + return PandasBackend() + + if BackendFactory._is_polars(entity_df): + return BackendFactory._get_polars_backend() + + if BackendFactory._is_spark(entity_df): + return BackendFactory._get_spark_backend() + + if BackendFactory._is_dask(entity_df): + return BackendFactory._get_dask_backend() + + return None + + @staticmethod + def _is_polars(entity_df) -> bool: + try: + import polars as pl + except ImportError: + raise ImportError( + "Polars is not installed. Please install it to use Polars backend." 
+ ) + return isinstance(entity_df, pl.DataFrame) + + @staticmethod + def _get_polars_backend(): + from feast.infra.compute_engines.backends.polars_backend import ( + PolarsBackend, + ) + + return PolarsBackend() + + @staticmethod + def _is_spark(entity_df) -> bool: + try: + from pyspark.sql import DataFrame as SparkDataFrame + except ImportError: + return False + return isinstance(entity_df, SparkDataFrame) + + @staticmethod + def _get_spark_backend(): + from feast.infra.compute_engines.backends.spark_backend import SparkBackend + + return SparkBackend() + + @staticmethod + def _is_dask(entity_df) -> bool: + try: + import dask.dataframe as dd + except ImportError: + return False + return isinstance(entity_df, dd.DataFrame) + + @staticmethod + def _get_dask_backend(): + from feast.infra.compute_engines.backends.dask_backend import DaskBackend + + return DaskBackend() diff --git a/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py b/sdk/python/feast/infra/compute_engines/backends/pandas_backend.py similarity index 61% rename from sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py rename to sdk/python/feast/infra/compute_engines/backends/pandas_backend.py index 76ddd688424..ff54303632b 100644 --- a/sdk/python/feast/infra/compute_engines/local/backends/pandas_backend.py +++ b/sdk/python/feast/infra/compute_engines/backends/pandas_backend.py @@ -1,9 +1,13 @@ from datetime import timedelta +from typing import TYPE_CHECKING import pandas as pd import pyarrow as pa -from feast.infra.compute_engines.local.backends.base import DataFrameBackend +from feast.infra.compute_engines.backends.base import DataFrameBackend + +if TYPE_CHECKING: + import numpy as np class PandasBackend(DataFrameBackend): @@ -44,3 +48,15 @@ def drop_duplicates(self, df, keys, sort_by, ascending: bool = False): def rename_columns(self, df, columns: dict[str, str]): return df.rename(columns=columns) + + def get_schema(self, df) -> dict[str, "np.dtype"]: + """Get pandas DataFrame schema as column name to numpy dtype mapping.""" + return {col: dtype for col, dtype in df.dtypes.items()} + + def get_timestamp_range(self, df, timestamp_column: str) -> tuple: + """Get min/max of a timestamp column in pandas DataFrame.""" + col = df[timestamp_column] + # Ensure it's datetime type + if not pd.api.types.is_datetime64_any_dtype(col): + col = pd.to_datetime(col, utc=True) + return (col.min().to_pydatetime(), col.max().to_pydatetime()) diff --git a/sdk/python/feast/infra/compute_engines/backends/polars_backend.py b/sdk/python/feast/infra/compute_engines/backends/polars_backend.py new file mode 100644 index 00000000000..c1982e66131 --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/backends/polars_backend.py @@ -0,0 +1,90 @@ +from datetime import timedelta +from typing import TYPE_CHECKING + +import polars as pl +import pyarrow as pa + +from feast.infra.compute_engines.backends.base import DataFrameBackend + +if TYPE_CHECKING: + import numpy as np + + +class PolarsBackend(DataFrameBackend): + def columns(self, df: pl.DataFrame) -> list[str]: + return df.columns + + def from_arrow(self, table: pa.Table) -> pl.DataFrame: + return pl.from_arrow(table) + + def to_arrow(self, df: pl.DataFrame) -> pa.Table: + return df.to_arrow() + + def join(self, left: pl.DataFrame, right: pl.DataFrame, on, how) -> pl.DataFrame: + return left.join(right, on=on, how=how) + + def groupby_agg(self, df: pl.DataFrame, group_keys, agg_ops) -> pl.DataFrame: + agg_exprs = [ + getattr(pl.col(col), func)().alias(alias) + 
for alias, (func, col) in agg_ops.items() + ] + return df.groupby(group_keys).agg(agg_exprs) + + def filter(self, df: pl.DataFrame, expr: str) -> pl.DataFrame: + return df.filter(pl.sql_expr(expr)) + + def to_timedelta_value(self, delta: timedelta): + return pl.duration(milliseconds=delta.total_seconds() * 1000) + + def drop_duplicates( + self, + df: pl.DataFrame, + keys: list[str], + sort_by: list[str], + ascending: bool = False, + ) -> pl.DataFrame: + return df.sort(by=sort_by, descending=not ascending).unique( + subset=keys, keep="first" + ) + + def rename_columns(self, df: pl.DataFrame, columns: dict[str, str]) -> pl.DataFrame: + return df.rename(columns) + + def get_schema(self, df: pl.DataFrame) -> dict[str, "np.dtype"]: + """Get Polars DataFrame schema as column name to numpy dtype mapping.""" + import numpy as np + + # Convert Polars dtypes to numpy dtypes + def polars_to_numpy_dtype(polars_dtype): + dtype_map = { + pl.Boolean: np.dtype("bool"), + pl.Int8: np.dtype("int8"), + pl.Int16: np.dtype("int16"), + pl.Int32: np.dtype("int32"), + pl.Int64: np.dtype("int64"), + pl.UInt8: np.dtype("uint8"), + pl.UInt16: np.dtype("uint16"), + pl.UInt32: np.dtype("uint32"), + pl.UInt64: np.dtype("uint64"), + pl.Float32: np.dtype("float32"), + pl.Float64: np.dtype("float64"), + pl.Utf8: np.dtype("O"), # Object dtype for strings + pl.Date: np.dtype("datetime64[D]"), + pl.Datetime: np.dtype("datetime64[ns]"), + } + return dtype_map.get(polars_dtype, np.dtype("O")) # Default to object + + return {col: polars_to_numpy_dtype(dtype) for col, dtype in df.schema.items()} + + def get_timestamp_range(self, df: pl.DataFrame, timestamp_column: str) -> tuple: + """Get min/max of a timestamp column in Polars DataFrame.""" + min_val = df[timestamp_column].min() + max_val = df[timestamp_column].max() + + # Convert to datetime objects if needed + if hasattr(min_val, "to_py"): + min_val = min_val.to_py() + if hasattr(max_val, "to_py"): + max_val = max_val.to_py() + + return (min_val, max_val) diff --git a/sdk/python/feast/infra/compute_engines/backends/spark_backend.py b/sdk/python/feast/infra/compute_engines/backends/spark_backend.py new file mode 100644 index 00000000000..e79d2250f3d --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/backends/spark_backend.py @@ -0,0 +1,163 @@ +from datetime import timedelta +from typing import TYPE_CHECKING, Dict, List, Tuple + +import pyarrow as pa +from pyspark.sql import DataFrame as SparkDataFrame +from pyspark.sql.functions import asc, col, desc + +from feast.infra.compute_engines.backends.base import DataFrameBackend + +if TYPE_CHECKING: + import numpy as np + + +class SparkBackend(DataFrameBackend): + """Spark DataFrame backend implementation.""" + + def __init__(self): + if ( + not hasattr(SparkDataFrame, "__name__") + or SparkDataFrame.__name__ == "SparkDataFrame" + ): + raise ImportError( + "PySpark is not installed. Please install it to use Spark backend." 
+ ) + + def columns(self, df: SparkDataFrame) -> List[str]: + """Get column names from Spark DataFrame.""" + return df.columns + + def from_arrow(self, table: pa.Table) -> SparkDataFrame: + """Convert Arrow table to Spark DataFrame.""" + from pyspark.sql import SparkSession + + spark = SparkSession.getActiveSession() + if spark is None: + raise RuntimeError("No active Spark session found") + return spark.createDataFrame(table.to_pandas()) + + def to_arrow(self, df: SparkDataFrame) -> pa.Table: + """Convert Spark DataFrame to Arrow table.""" + return pa.Table.from_pandas(df.toPandas()) + + def join( + self, left: SparkDataFrame, right: SparkDataFrame, on: List[str], how: str + ) -> SparkDataFrame: + """Join two Spark DataFrames.""" + return left.join(right, on=on, how=how) + + def groupby_agg( + self, + df: SparkDataFrame, + group_keys: List[str], + agg_ops: Dict[str, Tuple[str, str]], + ) -> SparkDataFrame: + """Group and aggregate Spark DataFrame.""" + + # Convert agg_ops to Spark aggregation expressions + agg_exprs = {} + for alias, (func, column) in agg_ops.items(): + if func == "count": + agg_exprs[alias] = {"count": column} + elif func == "sum": + agg_exprs[alias] = {"sum": column} + elif func == "avg": + agg_exprs[alias] = {"avg": column} + elif func == "min": + agg_exprs[alias] = {"min": column} + elif func == "max": + agg_exprs[alias] = {"max": column} + else: + raise ValueError(f"Unsupported aggregation function: {func}") + + # Flatten the expressions for Spark's agg() method + spark_agg_exprs = {} + for alias, agg_dict in agg_exprs.items(): + for agg_func, col_name in agg_dict.items(): + spark_agg_exprs[f"{agg_func}({col_name})"] = alias + + result = df.groupBy(*group_keys).agg(spark_agg_exprs) + + # Rename columns to match expected aliases + for old_name, new_name in spark_agg_exprs.items(): + if old_name != new_name: + result = result.withColumnRenamed(old_name, new_name) + + return result + + def filter(self, df: SparkDataFrame, expr: str) -> SparkDataFrame: + """Apply filter expression to Spark DataFrame.""" + return df.filter(expr) + + def to_timedelta_value(self, delta: timedelta) -> str: + """Convert timedelta to Spark interval string.""" + total_seconds = int(delta.total_seconds()) + if total_seconds < 60: + return f"INTERVAL {total_seconds} SECONDS" + elif total_seconds < 3600: + minutes = total_seconds // 60 + seconds = total_seconds % 60 + if seconds == 0: + return f"INTERVAL {minutes} MINUTES" + else: + return f"INTERVAL '{minutes}:{seconds:02d}' MINUTE TO SECOND" + else: + hours = total_seconds // 3600 + remainder = total_seconds % 3600 + minutes = remainder // 60 + seconds = remainder % 60 + return f"INTERVAL '{hours}:{minutes:02d}:{seconds:02d}' HOUR TO SECOND" + + def drop_duplicates( + self, + df: SparkDataFrame, + keys: List[str], + sort_by: List[str], + ascending: bool = False, + ) -> SparkDataFrame: + """Deduplicate Spark DataFrame by keys, sorted by sort_by columns.""" + from pyspark.sql.functions import row_number + from pyspark.sql.window import Window + + # Create window spec for deduplication + window_spec = Window.partitionBy(*keys) + + # Add sort order + if ascending: + window_spec = window_spec.orderBy(*[asc(col) for col in sort_by]) + else: + window_spec = window_spec.orderBy(*[desc(col) for col in sort_by]) + + # Add row number and filter to keep only first row + return ( + df.withColumn("__row_number", row_number().over(window_spec)) + .filter(col("__row_number") == 1) + .drop("__row_number") + ) + + def rename_columns( + self, df: SparkDataFrame, 
columns: Dict[str, str] + ) -> SparkDataFrame: + """Rename columns in Spark DataFrame.""" + result = df + for old_name, new_name in columns.items(): + result = result.withColumnRenamed(old_name, new_name) + return result + + def get_schema(self, df: SparkDataFrame) -> Dict[str, "np.dtype"]: + """Get Spark DataFrame schema as column name to numpy dtype mapping.""" + from feast.type_map import spark_schema_to_np_dtypes + + return dict(zip(df.columns, spark_schema_to_np_dtypes(df.dtypes))) + + def get_timestamp_range(self, df: SparkDataFrame, timestamp_column: str) -> tuple: + """Get min/max of a timestamp column in Spark DataFrame.""" + from pyspark.sql.functions import max as spark_max + from pyspark.sql.functions import min as spark_min + + result = df.select( + spark_min(col(timestamp_column)).alias("min_ts"), + spark_max(col(timestamp_column)).alias("max_ts"), + ).collect()[0] + + return (result["min_ts"], result["max_ts"]) diff --git a/sdk/python/feast/infra/compute_engines/local/backends/__init__.py b/sdk/python/feast/infra/compute_engines/local/backends/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/sdk/python/feast/infra/compute_engines/local/backends/factory.py b/sdk/python/feast/infra/compute_engines/local/backends/factory.py deleted file mode 100644 index 6d3774f6393..00000000000 --- a/sdk/python/feast/infra/compute_engines/local/backends/factory.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional - -import pandas as pd -import pyarrow - -from feast.infra.compute_engines.local.backends.base import DataFrameBackend -from feast.infra.compute_engines.local.backends.pandas_backend import PandasBackend - - -class BackendFactory: - """ - Factory class for constructing DataFrameBackend implementations based on backend name - or runtime entity_df type. - """ - - @staticmethod - def from_name(name: str) -> DataFrameBackend: - if name == "pandas": - return PandasBackend() - if name == "polars": - return BackendFactory._get_polars_backend() - raise ValueError(f"Unsupported backend name: {name}") - - @staticmethod - def infer_from_entity_df(entity_df) -> Optional[DataFrameBackend]: - if ( - not entity_df - or isinstance(entity_df, pyarrow.Table) - or isinstance(entity_df, pd.DataFrame) - ): - return PandasBackend() - - if BackendFactory._is_polars(entity_df): - return BackendFactory._get_polars_backend() - return None - - @staticmethod - def _is_polars(entity_df) -> bool: - try: - import polars as pl - except ImportError: - raise ImportError( - "Polars is not installed. Please install it to use Polars backend." 
- ) - return isinstance(entity_df, pl.DataFrame) - - @staticmethod - def _get_polars_backend(): - from feast.infra.compute_engines.local.backends.polars_backend import ( - PolarsBackend, - ) - - return PolarsBackend() diff --git a/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py b/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py deleted file mode 100644 index 352ffecdab8..00000000000 --- a/sdk/python/feast/infra/compute_engines/local/backends/polars_backend.py +++ /dev/null @@ -1,47 +0,0 @@ -from datetime import timedelta - -import polars as pl -import pyarrow as pa - -from feast.infra.compute_engines.local.backends.base import DataFrameBackend - - -class PolarsBackend(DataFrameBackend): - def columns(self, df: pl.DataFrame) -> list[str]: - return df.columns - - def from_arrow(self, table: pa.Table) -> pl.DataFrame: - return pl.from_arrow(table) - - def to_arrow(self, df: pl.DataFrame) -> pa.Table: - return df.to_arrow() - - def join(self, left: pl.DataFrame, right: pl.DataFrame, on, how) -> pl.DataFrame: - return left.join(right, on=on, how=how) - - def groupby_agg(self, df: pl.DataFrame, group_keys, agg_ops) -> pl.DataFrame: - agg_exprs = [ - getattr(pl.col(col), func)().alias(alias) - for alias, (func, col) in agg_ops.items() - ] - return df.groupby(group_keys).agg(agg_exprs) - - def filter(self, df: pl.DataFrame, expr: str) -> pl.DataFrame: - return df.filter(pl.sql_expr(expr)) - - def to_timedelta_value(self, delta: timedelta): - return pl.duration(milliseconds=delta.total_seconds() * 1000) - - def drop_duplicates( - self, - df: pl.DataFrame, - keys: list[str], - sort_by: list[str], - ascending: bool = False, - ) -> pl.DataFrame: - return df.sort(by=sort_by, descending=not ascending).unique( - subset=keys, keep="first" - ) - - def rename_columns(self, df: pl.DataFrame, columns: dict[str, str]) -> pl.DataFrame: - return df.rename(columns) From 1528f6f8fdb0c12d9fe5e1b0fad1acfdce9705c2 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 14 Sep 2025 16:13:44 -0700 Subject: [PATCH 2/3] update Signed-off-by: HaoXuAI --- .../infra/compute_engines/backends/factory.py | 19 ++ .../compute_engines/backends/ray_backend.py | 174 ++++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 sdk/python/feast/infra/compute_engines/backends/ray_backend.py diff --git a/sdk/python/feast/infra/compute_engines/backends/factory.py b/sdk/python/feast/infra/compute_engines/backends/factory.py index 05d4b56975d..96a9f58de04 100644 --- a/sdk/python/feast/infra/compute_engines/backends/factory.py +++ b/sdk/python/feast/infra/compute_engines/backends/factory.py @@ -23,6 +23,8 @@ def from_name(name: str) -> DataFrameBackend: return BackendFactory._get_spark_backend() if name == "dask": return BackendFactory._get_dask_backend() + if name == "ray": + return BackendFactory._get_ray_backend() raise ValueError(f"Unsupported backend name: {name}") @staticmethod @@ -43,6 +45,9 @@ def infer_from_entity_df(entity_df) -> Optional[DataFrameBackend]: if BackendFactory._is_dask(entity_df): return BackendFactory._get_dask_backend() + if BackendFactory._is_ray(entity_df): + return BackendFactory._get_ray_backend() + return None @staticmethod @@ -90,3 +95,17 @@ def _get_dask_backend(): from feast.infra.compute_engines.backends.dask_backend import DaskBackend return DaskBackend() + + @staticmethod + def _is_ray(entity_df) -> bool: + try: + import ray.data + except ImportError: + return False + return isinstance(entity_df, ray.data.Dataset) + + @staticmethod + def 
_get_ray_backend(): + from feast.infra.compute_engines.backends.ray_backend import RayBackend + + return RayBackend() diff --git a/sdk/python/feast/infra/compute_engines/backends/ray_backend.py b/sdk/python/feast/infra/compute_engines/backends/ray_backend.py new file mode 100644 index 00000000000..dbe7c9349fa --- /dev/null +++ b/sdk/python/feast/infra/compute_engines/backends/ray_backend.py @@ -0,0 +1,174 @@ +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +import ray.data +import pyarrow as pa + +from feast.infra.compute_engines.backends.base import DataFrameBackend + +if TYPE_CHECKING: + import numpy as np + + +class RayBackend(DataFrameBackend): + """Ray Dataset backend implementation.""" + + def __init__(self): + if not hasattr(ray.data, 'Dataset'): + raise ImportError( + "Ray is not installed. Please install it to use Ray backend." + ) + + def columns(self, df: ray.data.Dataset) -> List[str]: + """Get column names from Ray Dataset.""" + return df.schema().names + + def from_arrow(self, table: pa.Table) -> ray.data.Dataset: + """Convert Arrow table to Ray Dataset.""" + return ray.data.from_arrow(table) + + def to_arrow(self, df: ray.data.Dataset) -> pa.Table: + """Convert Ray Dataset to Arrow table.""" + return df.to_arrow() + + def join( + self, left: ray.data.Dataset, right: ray.data.Dataset, on: List[str], how: str + ) -> ray.data.Dataset: + """Join two Ray Datasets.""" + # Ray doesn't support arbitrary join keys directly, convert via Arrow + left_arrow = self.to_arrow(left) + right_arrow = self.to_arrow(right) + + # Use pandas for the join operation + import pandas as pd + left_pd = left_arrow.to_pandas() + right_pd = right_arrow.to_pandas() + + result_pd = left_pd.merge(right_pd, on=on, how=how) + result_arrow = pa.Table.from_pandas(result_pd) + return self.from_arrow(result_arrow) + + def groupby_agg( + self, + df: ray.data.Dataset, + group_keys: List[str], + agg_ops: Dict[str, Tuple[str, str]], + ) -> ray.data.Dataset: + """Group and aggregate Ray Dataset.""" + # Ray's groupby is limited, so we'll use pandas conversion + arrow_table = self.to_arrow(df) + pandas_df = arrow_table.to_pandas() + + # Use pandas groupby + import pandas as pd + agg_dict = { + col_name: func for alias, (func, col_name) in agg_ops.items() + } + + result = pandas_df.groupby(group_keys).agg(agg_dict).reset_index() + + # Rename columns to match expected aliases + rename_mapping = {} + for alias, (func, col_name) in agg_ops.items(): + if func in ["sum", "count", "mean", "min", "max"]: + old_name = f"{col_name}" + if old_name in result.columns: + rename_mapping[old_name] = alias + + if rename_mapping: + result = result.rename(columns=rename_mapping) + + # Convert back to Ray Dataset + result_arrow = pa.Table.from_pandas(result) + return self.from_arrow(result_arrow) + + def filter(self, df: ray.data.Dataset, expr: str) -> ray.data.Dataset: + """Apply filter expression to Ray Dataset.""" + # Ray has limited SQL support, convert to pandas for complex expressions + arrow_table = self.to_arrow(df) + pandas_df = arrow_table.to_pandas() + filtered_df = pandas_df.query(expr) + result_arrow = pa.Table.from_pandas(filtered_df) + return self.from_arrow(result_arrow) + + def to_timedelta_value(self, delta: timedelta) -> Any: + """Convert timedelta to Ray-compatible timedelta.""" + import pandas as pd + return pd.to_timedelta(delta) + + def drop_duplicates( + self, + df: ray.data.Dataset, + keys: List[str], + sort_by: List[str], + ascending: bool = False, + ) -> 
ray.data.Dataset: + """Deduplicate Ray Dataset by keys, sorted by sort_by columns.""" + # Convert to pandas for deduplication + arrow_table = self.to_arrow(df) + pandas_df = arrow_table.to_pandas() + + result_df = pandas_df.sort_values(by=sort_by, ascending=ascending).drop_duplicates( + subset=keys + ) + + result_arrow = pa.Table.from_pandas(result_df) + return self.from_arrow(result_arrow) + + def rename_columns(self, df: ray.data.Dataset, columns: Dict[str, str]) -> ray.data.Dataset: + """Rename columns in Ray Dataset.""" + # Ray doesn't have direct rename, so convert via Arrow + arrow_table = self.to_arrow(df) + pandas_df = arrow_table.to_pandas() + renamed_df = pandas_df.rename(columns=columns) + result_arrow = pa.Table.from_pandas(renamed_df) + return self.from_arrow(result_arrow) + + def get_schema(self, df: ray.data.Dataset) -> Dict[str, "np.dtype"]: + """Get Ray Dataset schema as column name to numpy dtype mapping.""" + import numpy as np + + # Convert Ray schema to numpy dtypes + schema = df.schema() + result = {} + + for field in schema: + # Map Arrow types to numpy dtypes + arrow_type = field.type + if pa.types.is_boolean(arrow_type): + numpy_dtype = np.dtype("bool") + elif pa.types.is_integer(arrow_type): + numpy_dtype = np.dtype("int64") + elif pa.types.is_floating(arrow_type): + numpy_dtype = np.dtype("float64") + elif pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): + numpy_dtype = np.dtype("O") # Object dtype for strings + elif pa.types.is_timestamp(arrow_type): + numpy_dtype = np.dtype("datetime64[ns]") + elif pa.types.is_date(arrow_type): + numpy_dtype = np.dtype("datetime64[D]") + else: + numpy_dtype = np.dtype("O") # Default to object + + result[field.name] = numpy_dtype + + return result + + def get_timestamp_range(self, df: ray.data.Dataset, timestamp_column: str) -> tuple: + """Get min/max of a timestamp column in Ray Dataset.""" + # Use Ray's built-in aggregation for min/max + stats = df.aggregate( + ray.data.aggregate.Min(timestamp_column), + ray.data.aggregate.Max(timestamp_column), + ) + + min_val = stats[f"min({timestamp_column})"] + max_val = stats[f"max({timestamp_column})"] + + # Convert to datetime objects if needed + if hasattr(min_val, 'to_pydatetime'): + min_val = min_val.to_pydatetime() + if hasattr(max_val, 'to_pydatetime'): + max_val = max_val.to_pydatetime() + + return (min_val, max_val) \ No newline at end of file From 12d8c6719d9833d17146116d1d52bcaab030cba3 Mon Sep 17 00:00:00 2001 From: HaoXuAI Date: Sun, 14 Sep 2025 16:16:01 -0700 Subject: [PATCH 3/3] linting Signed-off-by: HaoXuAI --- .../compute_engines/backends/ray_backend.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/sdk/python/feast/infra/compute_engines/backends/ray_backend.py b/sdk/python/feast/infra/compute_engines/backends/ray_backend.py index dbe7c9349fa..245662c4178 100644 --- a/sdk/python/feast/infra/compute_engines/backends/ray_backend.py +++ b/sdk/python/feast/infra/compute_engines/backends/ray_backend.py @@ -1,8 +1,8 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Dict, List, Tuple -import ray.data import pyarrow as pa +import ray.data from feast.infra.compute_engines.backends.base import DataFrameBackend @@ -14,7 +14,7 @@ class RayBackend(DataFrameBackend): """Ray Dataset backend implementation.""" def __init__(self): - if not hasattr(ray.data, 'Dataset'): + if not hasattr(ray.data, "Dataset"): raise ImportError( "Ray is not installed. Please install it to use Ray backend." 
) @@ -38,12 +38,11 @@ def join( # Ray doesn't support arbitrary join keys directly, convert via Arrow left_arrow = self.to_arrow(left) right_arrow = self.to_arrow(right) - + # Use pandas for the join operation - import pandas as pd left_pd = left_arrow.to_pandas() right_pd = right_arrow.to_pandas() - + result_pd = left_pd.merge(right_pd, on=on, how=how) result_arrow = pa.Table.from_pandas(result_pd) return self.from_arrow(result_arrow) @@ -58,15 +57,12 @@ def groupby_agg( # Ray's groupby is limited, so we'll use pandas conversion arrow_table = self.to_arrow(df) pandas_df = arrow_table.to_pandas() - + # Use pandas groupby - import pandas as pd - agg_dict = { - col_name: func for alias, (func, col_name) in agg_ops.items() - } - + agg_dict = {col_name: func for alias, (func, col_name) in agg_ops.items()} + result = pandas_df.groupby(group_keys).agg(agg_dict).reset_index() - + # Rename columns to match expected aliases rename_mapping = {} for alias, (func, col_name) in agg_ops.items(): @@ -74,10 +70,10 @@ def groupby_agg( old_name = f"{col_name}" if old_name in result.columns: rename_mapping[old_name] = alias - + if rename_mapping: result = result.rename(columns=rename_mapping) - + # Convert back to Ray Dataset result_arrow = pa.Table.from_pandas(result) return self.from_arrow(result_arrow) @@ -94,6 +90,7 @@ def filter(self, df: ray.data.Dataset, expr: str) -> ray.data.Dataset: def to_timedelta_value(self, delta: timedelta) -> Any: """Convert timedelta to Ray-compatible timedelta.""" import pandas as pd + return pd.to_timedelta(delta) def drop_duplicates( @@ -107,15 +104,17 @@ def drop_duplicates( # Convert to pandas for deduplication arrow_table = self.to_arrow(df) pandas_df = arrow_table.to_pandas() - - result_df = pandas_df.sort_values(by=sort_by, ascending=ascending).drop_duplicates( - subset=keys - ) - + + result_df = pandas_df.sort_values( + by=sort_by, ascending=ascending + ).drop_duplicates(subset=keys) + result_arrow = pa.Table.from_pandas(result_df) return self.from_arrow(result_arrow) - def rename_columns(self, df: ray.data.Dataset, columns: Dict[str, str]) -> ray.data.Dataset: + def rename_columns( + self, df: ray.data.Dataset, columns: Dict[str, str] + ) -> ray.data.Dataset: """Rename columns in Ray Dataset.""" # Ray doesn't have direct rename, so convert via Arrow arrow_table = self.to_arrow(df) @@ -127,14 +126,15 @@ def rename_columns(self, df: ray.data.Dataset, columns: Dict[str, str]) -> ray.d def get_schema(self, df: ray.data.Dataset) -> Dict[str, "np.dtype"]: """Get Ray Dataset schema as column name to numpy dtype mapping.""" import numpy as np - + # Convert Ray schema to numpy dtypes schema = df.schema() result = {} - + for field in schema: # Map Arrow types to numpy dtypes arrow_type = field.type + numpy_dtype: "np.dtype[Any]" if pa.types.is_boolean(arrow_type): numpy_dtype = np.dtype("bool") elif pa.types.is_integer(arrow_type): @@ -149,9 +149,9 @@ def get_schema(self, df: ray.data.Dataset) -> Dict[str, "np.dtype"]: numpy_dtype = np.dtype("datetime64[D]") else: numpy_dtype = np.dtype("O") # Default to object - + result[field.name] = numpy_dtype - + return result def get_timestamp_range(self, df: ray.data.Dataset, timestamp_column: str) -> tuple: @@ -161,14 +161,14 @@ def get_timestamp_range(self, df: ray.data.Dataset, timestamp_column: str) -> tu ray.data.aggregate.Min(timestamp_column), ray.data.aggregate.Max(timestamp_column), ) - + min_val = stats[f"min({timestamp_column})"] max_val = stats[f"max({timestamp_column})"] - + # Convert to datetime objects if needed - if 
hasattr(min_val, 'to_pydatetime'): + if hasattr(min_val, "to_pydatetime"): min_val = min_val.to_pydatetime() - if hasattr(max_val, 'to_pydatetime'): + if hasattr(max_val, "to_pydatetime"): max_val = max_val.to_pydatetime() - - return (min_val, max_val) \ No newline at end of file + + return (min_val, max_val)
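
Reviewer note (not part of the applied diff): a minimal usage sketch of the backend API added in this series, assuming the pandas backend. The sample column names (driver_id, event_timestamp, trips) are invented for illustration; BackendFactory.from_name, get_schema, get_timestamp_range, and drop_duplicates are the interfaces introduced or exercised by these patches.

    import pandas as pd

    from feast.infra.compute_engines.backends.factory import BackendFactory

    # Illustrative entity dataframe; the column names are made up for this example.
    entity_df = pd.DataFrame(
        {
            "driver_id": [1, 1, 2],
            "event_timestamp": pd.to_datetime(
                ["2025-09-01", "2025-09-02", "2025-09-02"], utc=True
            ),
            "trips": [3, 5, 2],
        }
    )

    # Pick a backend explicitly by name ("pandas", "polars", "spark", "dask", "ray").
    backend = BackendFactory.from_name("pandas")

    # New in this series: schema and timestamp-range introspection.
    schema = backend.get_schema(entity_df)  # {"driver_id": dtype('int64'), ...}
    start, end = backend.get_timestamp_range(entity_df, "event_timestamp")

    # Existing backend operations keep the same signature across engines.
    latest = backend.drop_duplicates(
        entity_df, keys=["driver_id"], sort_by=["event_timestamp"]
    )

BackendFactory.infer_from_entity_df(entity_df) can also select a backend from the runtime type of entity_df (pandas/Arrow, Polars, Spark, Dask, or Ray), returning None when the type is not recognized.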