20 changes: 19 additions & 1 deletion sdk/python/feast/dataframe.py
@@ -1,11 +1,29 @@
"""FeastDataFrame: A lightweight container for DataFrame-like objects in Feast."""

from enum import Enum
from typing import Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import pandas as pd
import pyarrow as pa

if TYPE_CHECKING:
import dask.dataframe
import polars
import pyspark.sql
import ray.data


# Type alias for entity DataFrame - supports various DataFrame types
DataFrameType = Union[
"pd.DataFrame",
"pyspark.sql.DataFrame",
"dask.dataframe.DataFrame",
"polars.DataFrame",
"ray.data.Dataset",
"pa.Table",
str,
]


class DataFrameEngine(str, Enum):
"""Supported DataFrame engines."""
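The new alias is meant for annotating entry points that accept any supported frame. A minimal sketch of such a signature (the function below is hypothetical, not part of this diff):

from feast.dataframe import DataFrameType


def load_entities(entity_df: DataFrameType) -> None:
    """Accept pandas, Spark, Dask, Polars, Ray, Arrow, or a string (e.g. a SQL query)."""
    ...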
sdk/python/feast/infra/compute_engines/backends/base.py
@@ -1,5 +1,9 @@
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import numpy as np


class DataFrameBackend(ABC):
@@ -18,6 +22,8 @@ class DataFrameBackend(ABC):
Expected implementations include:
- PandasBackend
- PolarsBackend
- SparkBackend
- DaskBackend
- DuckDBBackend (future)

Methods
@@ -77,3 +83,27 @@ def drop_duplicates(self, df, keys, sort_by, ascending: bool = False):

@abstractmethod
def rename_columns(self, df, columns: dict[str, str]): ...

@abstractmethod
def get_schema(self, df) -> dict[str, "np.dtype"]:
"""
Get the schema of the DataFrame as a dictionary of column names to numpy data types.

Returns:
Dictionary mapping column names to their numpy dtype objects
"""
...

@abstractmethod
def get_timestamp_range(self, df, timestamp_column: str) -> tuple:
"""
Get the min and max values of a timestamp column.

Args:
df: The DataFrame
timestamp_column: Name of the timestamp column

Returns:
Tuple of (min_timestamp, max_timestamp) as datetime objects
"""
...
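PolarsBackend is referenced by the factory below but its file is not shown in this diff, so here is a minimal sketch of how a backend might satisfy the two new abstract methods (illustrative only, not the actual implementation):

import polars as pl


class SketchPolarsBackend:
    def get_schema(self, df: pl.DataFrame) -> dict:
        # Round-trip through pandas to obtain numpy dtypes, per the contract.
        return {col: dtype for col, dtype in df.to_pandas().dtypes.items()}

    def get_timestamp_range(self, df: pl.DataFrame, timestamp_column: str) -> tuple:
        # Polars Series.min()/max() return native datetimes for Datetime columns.
        col = df.get_column(timestamp_column)
        return (col.min(), col.max())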
115 changes: 115 additions & 0 deletions sdk/python/feast/infra/compute_engines/backends/dask_backend.py
@@ -0,0 +1,115 @@
from __future__ import annotations

from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

import pyarrow as pa

try:
import dask.dataframe as dd
except ImportError:
# Keep the module importable without Dask; __init__ below raises a clear
# error instead. Lazy annotations keep the dd.DataFrame hints safe too.
dd = None

from feast.infra.compute_engines.backends.base import DataFrameBackend

if TYPE_CHECKING:
import numpy as np


class DaskBackend(DataFrameBackend):
"""Dask DataFrame backend implementation."""

def __init__(self):
if dd is None:
raise ImportError(
"Dask is not installed. Please install it to use Dask backend."
)

def columns(self, df: dd.DataFrame) -> List[str]:
"""Get column names from Dask DataFrame."""
return df.columns.tolist()

def from_arrow(self, table: pa.Table) -> dd.DataFrame:
"""Convert Arrow table to Dask DataFrame."""
pandas_df = table.to_pandas()
return dd.from_pandas(pandas_df, npartitions=1)

def to_arrow(self, df: dd.DataFrame) -> pa.Table:
"""Convert Dask DataFrame to Arrow table."""
pandas_df = df.compute()
return pa.Table.from_pandas(pandas_df)

def join(
self, left: dd.DataFrame, right: dd.DataFrame, on: List[str], how: str
) -> dd.DataFrame:
"""Join two Dask DataFrames."""
return left.merge(right, on=on, how=how)

def groupby_agg(
self,
df: dd.DataFrame,
group_keys: List[str],
agg_ops: Dict[str, Tuple[str, str]],
) -> dd.DataFrame:
"""Group and aggregate Dask DataFrame."""

# Convert agg_ops ({alias: (func, column)}) to a pandas-style agg dict.
# Note: keyed by source column, so this assumes at most one aggregation
# per column.
agg_dict = {col_name: func for alias, (func, col_name) in agg_ops.items()}

result = df.groupby(group_keys).agg(agg_dict).reset_index()

# Rename the aggregated columns to their expected aliases
rename_mapping = {}
for alias, (func, col_name) in agg_ops.items():
if func in ["sum", "count", "mean", "min", "max"] and col_name in result.columns:
rename_mapping[col_name] = alias

if rename_mapping:
result = result.rename(columns=rename_mapping)

return result

def filter(self, df: dd.DataFrame, expr: str) -> dd.DataFrame:
"""Apply filter expression to Dask DataFrame."""
return df.query(expr)

def to_timedelta_value(self, delta: timedelta) -> Any:
"""Convert timedelta to Dask-compatible timedelta."""
import pandas as pd

return pd.to_timedelta(delta)

def drop_duplicates(
self,
df: dd.DataFrame,
keys: List[str],
sort_by: List[str],
ascending: bool = False,
) -> dd.DataFrame:
"""Deduplicate Dask DataFrame by keys, keeping the first row per key after sorting by sort_by columns."""
# sort_values triggers a shuffle in Dask, but a global order is needed
# here so that sort_by/ascending are actually honored.
return (
df.sort_values(sort_by, ascending=ascending)
.drop_duplicates(subset=keys)
.reset_index(drop=True)
)

def rename_columns(self, df: dd.DataFrame, columns: Dict[str, str]) -> dd.DataFrame:
"""Rename columns in Dask DataFrame."""
return df.rename(columns=columns)

def get_schema(self, df: dd.DataFrame) -> Dict[str, "np.dtype"]:
"""Get Dask DataFrame schema as column name to numpy dtype mapping."""
# Dask dtypes are pandas-compatible numpy dtypes
return {col: dtype for col, dtype in df.dtypes.items()}

def get_timestamp_range(self, df: dd.DataFrame, timestamp_column: str) -> tuple:
"""Get min/max of a timestamp column in Dask DataFrame."""
import pandas as pd

col = df[timestamp_column]
# Ensure it's datetime type
if not pd.api.types.is_datetime64_any_dtype(col.dtype):
col = dd.to_datetime(col, utc=True)

min_val = col.min().compute()
max_val = col.max().compute()

# Convert to datetime objects
if hasattr(min_val, "to_pydatetime"):
min_val = min_val.to_pydatetime()
if hasattr(max_val, "to_pydatetime"):
max_val = max_val.to_pydatetime()

return (min_val, max_val)
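A short usage sketch for the new backend (the entity data here is invented for illustration):

import pandas as pd
import pyarrow as pa

from feast.infra.compute_engines.backends.dask_backend import DaskBackend

backend = DaskBackend()
table = pa.Table.from_pandas(
    pd.DataFrame(
        {
            "driver_id": [1, 2],
            "event_timestamp": pd.to_datetime(["2024-01-01", "2024-01-02"]),
        }
    )
)
ddf = backend.from_arrow(table)
start, end = backend.get_timestamp_range(ddf, "event_timestamp")  # datetime objects
assert backend.to_arrow(ddf).num_rows == 2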
111 changes: 111 additions & 0 deletions sdk/python/feast/infra/compute_engines/backends/factory.py
@@ -0,0 +1,111 @@
from typing import Optional

import pandas as pd
import pyarrow

from feast.infra.compute_engines.backends.base import DataFrameBackend
from feast.infra.compute_engines.backends.pandas_backend import PandasBackend


class BackendFactory:
"""
Factory class for constructing DataFrameBackend implementations based on backend name
or runtime entity_df type.
"""

@staticmethod
def from_name(name: str) -> DataFrameBackend:
if name == "pandas":
return PandasBackend()
if name == "polars":
return BackendFactory._get_polars_backend()
if name == "spark":
return BackendFactory._get_spark_backend()
if name == "dask":
return BackendFactory._get_dask_backend()
if name == "ray":
return BackendFactory._get_ray_backend()
raise ValueError(f"Unsupported backend name: {name}")

@staticmethod
def infer_from_entity_df(entity_df) -> Optional[DataFrameBackend]:
# `not entity_df` raises for DataFrame-like objects (ambiguous truth
# value), so check for None explicitly.
if entity_df is None or isinstance(entity_df, (pyarrow.Table, pd.DataFrame)):
return PandasBackend()

if BackendFactory._is_polars(entity_df):
return BackendFactory._get_polars_backend()

if BackendFactory._is_spark(entity_df):
return BackendFactory._get_spark_backend()

if BackendFactory._is_dask(entity_df):
return BackendFactory._get_dask_backend()

if BackendFactory._is_ray(entity_df):
return BackendFactory._get_ray_backend()

return None

@staticmethod
def _is_polars(entity_df) -> bool:
try:
import polars as pl
except ImportError:
# Mirror the other _is_* helpers: if the library is missing, the
# entity_df simply cannot be of that type.
return False
return isinstance(entity_df, pl.DataFrame)

@staticmethod
def _get_polars_backend():
from feast.infra.compute_engines.backends.polars_backend import (
PolarsBackend,
)

return PolarsBackend()

@staticmethod
def _is_spark(entity_df) -> bool:
try:
from pyspark.sql import DataFrame as SparkDataFrame
except ImportError:
return False
return isinstance(entity_df, SparkDataFrame)

@staticmethod
def _get_spark_backend():
from feast.infra.compute_engines.backends.spark_backend import SparkBackend

return SparkBackend()

@staticmethod
def _is_dask(entity_df) -> bool:
try:
import dask.dataframe as dd
except ImportError:
return False
return isinstance(entity_df, dd.DataFrame)

@staticmethod
def _get_dask_backend():
from feast.infra.compute_engines.backends.dask_backend import DaskBackend

return DaskBackend()

@staticmethod
def _is_ray(entity_df) -> bool:
try:
import ray.data
except ImportError:
return False
return isinstance(entity_df, ray.data.Dataset)

@staticmethod
def _get_ray_backend():
from feast.infra.compute_engines.backends.ray_backend import RayBackend

return RayBackend()
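A quick sketch of both selection paths (values invented for illustration; from_name requires the corresponding library to be installed):

import pandas as pd

from feast.infra.compute_engines.backends.factory import BackendFactory

entity_df = pd.DataFrame({"driver_id": [1001, 1002]})

# Runtime inference: pandas and Arrow inputs map to PandasBackend.
backend = BackendFactory.infer_from_entity_df(entity_df)

# Explicit selection by name; unknown names raise ValueError.
backend = BackendFactory.from_name("dask")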
sdk/python/feast/infra/compute_engines/backends/pandas_backend.py
@@ -1,9 +1,13 @@
from datetime import timedelta
from typing import TYPE_CHECKING

import pandas as pd
import pyarrow as pa

from feast.infra.compute_engines.local.backends.base import DataFrameBackend
from feast.infra.compute_engines.backends.base import DataFrameBackend

if TYPE_CHECKING:
import numpy as np


class PandasBackend(DataFrameBackend):
@@ -44,3 +48,15 @@ def drop_duplicates(self, df, keys, sort_by, ascending: bool = False):

def rename_columns(self, df, columns: dict[str, str]):
return df.rename(columns=columns)

def get_schema(self, df) -> dict[str, "np.dtype"]:
"""Get pandas DataFrame schema as column name to numpy dtype mapping."""
return {col: dtype for col, dtype in df.dtypes.items()}

def get_timestamp_range(self, df, timestamp_column: str) -> tuple:
"""Get min/max of a timestamp column in pandas DataFrame."""
col = df[timestamp_column]
# Ensure it's datetime type
if not pd.api.types.is_datetime64_any_dtype(col):
col = pd.to_datetime(col, utc=True)
return (col.min().to_pydatetime(), col.max().to_pydatetime())
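A quick sketch exercising the two new methods (data invented for illustration):

import pandas as pd

from feast.infra.compute_engines.backends.pandas_backend import PandasBackend

df = pd.DataFrame(
    {
        "driver_id": [1, 2],
        "event_timestamp": pd.to_datetime(["2024-01-01", "2024-01-02"], utc=True),
    }
)
backend = PandasBackend()
schema = backend.get_schema(df)  # {"driver_id": dtype('int64'), ...}
start, end = backend.get_timestamp_range(df, "event_timestamp")  # datetime objects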