@@ -438,53 +438,16 @@ def to_csv_file(self) -> Tuple[str, str]:
             os.remove(local_file_name)
             temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
             self._create_temp_table(temp_table_name, desired_s3_folder)
-            base_features = list(self._base.columns)
-            event_time_identifier_feature_dtype = self._base[
-                self._event_time_identifier_feature_name
-            ].dtypes
-            self._event_time_identifier_feature_type = (
-                FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
-                    str(event_time_identifier_feature_dtype), None
-                )
-            )
-            query_string = self._construct_query_string(
-                FeatureGroupToBeMerged(
-                    base_features,
-                    self._included_feature_names if self._included_feature_names else base_features,
-                    self._included_feature_names if self._included_feature_names else base_features,
-                    _DEFAULT_CATALOG,
-                    _DEFAULT_DATABASE,
-                    temp_table_name,
-                    self._record_identifier_feature_name,
-                    FeatureDefinition(
-                        self._event_time_identifier_feature_name,
-                        self._event_time_identifier_feature_type,
-                    ),
-                    None,
-                    TableType.DATA_FRAME,
-                )
+            query_result = self._run_query(
+                *self._to_athena_query(temp_table_name=temp_table_name)
             )
-            query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
             # TODO: cleanup temp table, need more clarification, keep it for now
             return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
                 "OutputLocation", None
             ), query_result.get("QueryExecution", {}).get("Query", None)
         if isinstance(self._base, FeatureGroup):
-            base_feature_group = construct_feature_group_to_be_merged(
-                self._base, self._included_feature_names
-            )
-            self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
-            self._event_time_identifier_feature_name = (
-                base_feature_group.event_time_identifier_feature.feature_name
-            )
-            self._event_time_identifier_feature_type = (
-                base_feature_group.event_time_identifier_feature.feature_type
-            )
-            query_string = self._construct_query_string(base_feature_group)
             query_result = self._run_query(
-                query_string,
-                base_feature_group.catalog,
-                base_feature_group.database,
+                *self._to_athena_query()
             )
             return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
                 "OutputLocation", None
@@ -1058,6 +1021,67 @@ def _construct_athena_table_column_string(self, column: str) -> str:
             raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
         return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"
 
+    def _to_athena_query(self, temp_table_name: str = None) -> Tuple[str, str, str]:
+        """Internal method for constructing an Athena query.
+
+        Args:
+            temp_table_name (str): The temporary Athena table name of the base pandas.DataFrame.
+                Defaults to None.
+
+        Returns:
+            The query string.
+            The name of the catalog to be used in the query execution.
+            The database to be used in the query execution.
+
+        Raises:
+            ValueError: temp_table_name must be provided if the base is a pandas.DataFrame.
+        """
+        if isinstance(self._base, pd.DataFrame):
+            if temp_table_name is None:
+                raise ValueError("temp_table_name must be provided for a pandas.DataFrame base.")
+            base_features = list(self._base.columns)
+            event_time_identifier_feature_dtype = self._base[
+                self._event_time_identifier_feature_name
+            ].dtypes
+            self._event_time_identifier_feature_type = (
+                FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
+                    str(event_time_identifier_feature_dtype), None
+                )
+            )
+            catalog = _DEFAULT_CATALOG
+            database = _DEFAULT_DATABASE
+            query_string = self._construct_query_string(
+                FeatureGroupToBeMerged(
+                    base_features,
+                    self._included_feature_names if self._included_feature_names else base_features,
+                    self._included_feature_names if self._included_feature_names else base_features,
+                    catalog,
+                    database,
+                    temp_table_name,
+                    self._record_identifier_feature_name,
+                    FeatureDefinition(
+                        self._event_time_identifier_feature_name,
+                        self._event_time_identifier_feature_type,
+                    ),
+                    None,
+                    TableType.DATA_FRAME,
+                )
+            )
+        if isinstance(self._base, FeatureGroup):
+            base_feature_group = construct_feature_group_to_be_merged(
+                self._base, self._included_feature_names
+            )
+            self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
+            self._event_time_identifier_feature_name = (
+                base_feature_group.event_time_identifier_feature.feature_name
+            )
+            self._event_time_identifier_feature_type = (
+                base_feature_group.event_time_identifier_feature.feature_type
+            )
+            catalog = base_feature_group.catalog
+            database = base_feature_group.database
+            query_string = self._construct_query_string(base_feature_group)
+        return query_string, catalog, database
+
     def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
         """Internal method for executing an Athena query, waiting for it to finish, and getting the result.
 
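For reference, a minimal sketch of how the extracted helper composes with _run_query (not part of the diff; the builder object and the table name are assumed placeholders, and only the call pattern visible in the hunks above is used):

# Sketch only: `builder` stands in for a DatasetBuilder-like instance that
# exposes the private helpers shown in the diff; the table name is invented.
query_string, catalog, database = builder._to_athena_query(
    temp_table_name="dataframe_example_id"
)
query_result = builder._run_query(query_string, catalog, database)

# Mirror the returns in to_csv_file(): the S3 output location and the query text.
output_location = (
    query_result.get("QueryExecution", {})
    .get("ResultConfiguration", {})
    .get("OutputLocation", None)
)
executed_query = query_result.get("QueryExecution", {}).get("Query", None)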