@@ -173,7 +173,9 @@ def join_with_aggregated_features(batch: pd.DataFrame) -> pd.DataFrame:
173173 return result
174174
175175 joined_dataset = entity_dataset .map_batches (
176- join_with_aggregated_features , batch_format = "pandas"
176+ join_with_aggregated_features ,
177+ batch_format = "pandas" ,
178+ concurrency = self .config .max_workers or 12 ,
177179 )
178180 else :
179181 if feature_size <= self .config .broadcast_join_threshold_mb * 1024 * 1024 :
@@ -274,8 +276,8 @@ def apply_filters(batch: pd.DataFrame) -> pd.DataFrame:
274276 else :
275277 # Use current time for TTL calculation (real-time retrieval)
276278 # Check if timestamp column is timezone-aware
277- if pd . api . types . is_datetime64tz_dtype (
278- filtered_batch [timestamp_col ]
279+ if isinstance (
280+ filtered_batch [timestamp_col ]. dtype , pd . DatetimeTZDtype
279281 ):
280282 # Use timezone-aware current time
281283 current_time = datetime .now (timezone .utc )
@@ -517,31 +519,59 @@ def execute(self, context: ExecutionContext) -> DAGValue:
517519 input_value .assert_format (DAGFormat .RAY )
518520 dataset : Dataset = input_value .data
519521
520- transformation_serialized = None
521- if hasattr (self .transformation , "udf" ) and callable (self .transformation .udf ):
522- transformation_serialized = dill .dumps (self .transformation .udf )
523- elif callable (self .transformation ):
524- transformation_serialized = dill .dumps (self .transformation )
522+ # Check transformation mode
523+ from feast .transformation .mode import TransformationMode
525524
526- @safe_batch_processor
527- def apply_transformation_with_serialized_udf (
528- batch : pd .DataFrame ,
529- ) -> pd .DataFrame :
530- """Apply the transformation using pre-serialized UDF."""
531- if transformation_serialized :
532- transformation_func = dill .loads (transformation_serialized )
533- transformed_batch = transformation_func (batch )
525+ transformation_mode = getattr (
526+ self .transformation , "mode" , TransformationMode .PYTHON
527+ )
528+ is_ray_native = transformation_mode in (TransformationMode .RAY , "ray" )
529+ if is_ray_native :
530+ transformation_func = None
531+ if hasattr (self .transformation , "udf" ) and callable (
532+ self .transformation .udf
533+ ):
534+ transformation_func = self .transformation .udf
535+ elif callable (self .transformation ):
536+ transformation_func = self .transformation
537+
538+ if transformation_func :
539+ transformed_dataset = transformation_func (dataset )
534540 else :
535541 logger .warning (
536- "No serialized transformation available, returning original batch "
542+                    "No transformation function available in RAY mode, returning original dataset"
537543 )
538- transformed_batch = batch
544+ transformed_dataset = dataset
545+ else :
546+ transformation_serialized = None
547+ if hasattr (self .transformation , "udf" ) and callable (
548+ self .transformation .udf
549+ ):
550+ transformation_serialized = dill .dumps (self .transformation .udf )
551+ elif callable (self .transformation ):
552+ transformation_serialized = dill .dumps (self .transformation )
539553
540- return transformed_batch
554+ @safe_batch_processor
555+ def apply_transformation_with_serialized_udf (
556+ batch : pd .DataFrame ,
557+ ) -> pd .DataFrame :
558+ """Apply the transformation using pre-serialized UDF."""
559+ if transformation_serialized :
560+ transformation_func = dill .loads (transformation_serialized )
561+ transformed_batch = transformation_func (batch )
562+ else :
563+ logger .warning (
564+ "No serialized transformation available, returning original batch"
565+ )
566+ transformed_batch = batch
541567
542- transformed_dataset = dataset .map_batches (
543- apply_transformation_with_serialized_udf , batch_format = "pandas"
544- )
568+ return transformed_batch
569+
570+ transformed_dataset = dataset .map_batches (
571+ apply_transformation_with_serialized_udf ,
572+ batch_format = "pandas" ,
573+ concurrency = self .config .max_workers or 12 ,
574+ )
545575
546576 return DAGValue (
547577 data = transformed_dataset ,
@@ -598,7 +628,9 @@ def apply_transformation(batch: pd.DataFrame) -> pd.DataFrame:
598628 return transformation_func (batch )
599629
600630 transformed_dataset = parent_value .data .map_batches (
601- apply_transformation
631+ apply_transformation ,
632+ batch_format = "pandas" ,
633+ concurrency = self .config .max_workers or 12 ,
602634 )
603635 return DAGValue (
604636 data = transformed_dataset ,
@@ -630,9 +662,11 @@ def __init__(
630662 name : str ,
631663 feature_view : Union [BatchFeatureView , StreamFeatureView , FeatureView ],
632664 inputs = None ,
665+ config : Optional [RayComputeEngineConfig ] = None ,
633666 ):
634667 super ().__init__ (name , inputs = inputs )
635668 self .feature_view = feature_view
669+ self .config = config
636670
637671 def execute (self , context : ExecutionContext ) -> DAGValue :
638672 """Execute the write operation."""
@@ -676,7 +710,9 @@ def write_batch_with_serialized_artifacts(batch: pd.DataFrame) -> pd.DataFrame:
676710 return batch
677711
678712 written_dataset = dataset .map_batches (
679- write_batch_with_serialized_artifacts , batch_format = "pandas"
713+ write_batch_with_serialized_artifacts ,
714+ batch_format = "pandas" ,
715+ concurrency = self .config .max_workers if self .config else 12 ,
680716 )
681717 written_dataset = written_dataset .materialize ()
682718
0 commit comments