@@ -438,54 +438,13 @@ def to_csv_file(self) -> Tuple[str, str]:
438438 os .remove (local_file_name )
439439 temp_table_name = f'dataframe_{ temp_id .replace ("-" , "_" )} '
440440 self ._create_temp_table (temp_table_name , desired_s3_folder )
441- base_features = list (self ._base .columns )
442- event_time_identifier_feature_dtype = self ._base [
443- self ._event_time_identifier_feature_name
444- ].dtypes
445- self ._event_time_identifier_feature_type = (
446- FeatureGroup .DTYPE_TO_FEATURE_DEFINITION_CLS_MAP .get (
447- str (event_time_identifier_feature_dtype ), None
448- )
449- )
450- query_string = self ._construct_query_string (
451- FeatureGroupToBeMerged (
452- base_features ,
453- self ._included_feature_names if self ._included_feature_names else base_features ,
454- self ._included_feature_names if self ._included_feature_names else base_features ,
455- _DEFAULT_CATALOG ,
456- _DEFAULT_DATABASE ,
457- temp_table_name ,
458- self ._record_identifier_feature_name ,
459- FeatureDefinition (
460- self ._event_time_identifier_feature_name ,
461- self ._event_time_identifier_feature_type ,
462- ),
463- None ,
464- TableType .DATA_FRAME ,
465- )
466- )
467- query_result = self ._run_query (query_string , _DEFAULT_CATALOG , _DEFAULT_DATABASE )
441+ query_result = self ._run_query (* self ._to_athena_query (temp_table_name = temp_table_name ))
468442 # TODO: cleanup temp table, need more clarification, keep it for now
469443 return query_result .get ("QueryExecution" , {}).get ("ResultConfiguration" , {}).get (
470444 "OutputLocation" , None
471445 ), query_result .get ("QueryExecution" , {}).get ("Query" , None )
472446 if isinstance (self ._base , FeatureGroup ):
473- base_feature_group = construct_feature_group_to_be_merged (
474- self ._base , self ._included_feature_names
475- )
476- self ._record_identifier_feature_name = base_feature_group .record_identifier_feature_name
477- self ._event_time_identifier_feature_name = (
478- base_feature_group .event_time_identifier_feature .feature_name
479- )
480- self ._event_time_identifier_feature_type = (
481- base_feature_group .event_time_identifier_feature .feature_type
482- )
483- query_string = self ._construct_query_string (base_feature_group )
484- query_result = self ._run_query (
485- query_string ,
486- base_feature_group .catalog ,
487- base_feature_group .database ,
488- )
447+ query_result = self ._run_query (* self ._to_athena_query ())
489448 return query_result .get ("QueryExecution" , {}).get ("ResultConfiguration" , {}).get (
490449 "OutputLocation" , None
491450 ), query_result .get ("QueryExecution" , {}).get ("Query" , None )
@@ -1058,6 +1017,67 @@ def _construct_athena_table_column_string(self, column: str) -> str:
10581017 raise RuntimeError (f"The dataframe type { dataframe_type } is not supported yet." )
10591018 return f"{ column } { self ._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP .get (str (dataframe_type ), None )} "
10601019
1020+ def _to_athena_query (self , temp_table_name : str = None ) -> Tuple [str , str , str ]:
1021+ """Internal method for constructing an Athena query.
1022+
1023+ Args:
1024+ temp_table_name (str): The temporary Athena table name of the base pandas.DataFrame. Defaults to None.
1025+
1026+ Returns:
1027+ The query string.
1028+ The name of the catalog to be used in the query execution.
1029+ The database to be used in the query execution.
1030+
1031+ Raises:
1032+ ValueError: temp_table_name must be provided if the base is a pandas.DataFrame.
1033+ """
1034+ if isinstance (self ._base , pd .DataFrame ):
1035+ if temp_table_name is None :
1036+ raise ValueError ("temp_table_name must be provided for a pandas.DataFrame base." )
1037+ base_features = list (self ._base .columns )
1038+ event_time_identifier_feature_dtype = self ._base [
1039+ self ._event_time_identifier_feature_name
1040+ ].dtypes
1041+ self ._event_time_identifier_feature_type = (
1042+ FeatureGroup .DTYPE_TO_FEATURE_DEFINITION_CLS_MAP .get (
1043+ str (event_time_identifier_feature_dtype ), None
1044+ )
1045+ )
1046+ catalog = _DEFAULT_CATALOG
1047+ database = _DEFAULT_DATABASE
1048+ query_string = self ._construct_query_string (
1049+ FeatureGroupToBeMerged (
1050+ base_features ,
1051+ self ._included_feature_names if self ._included_feature_names else base_features ,
1052+ self ._included_feature_names if self ._included_feature_names else base_features ,
1053+ catalog ,
1054+ database ,
1055+ temp_table_name ,
1056+ self ._record_identifier_feature_name ,
1057+ FeatureDefinition (
1058+ self ._event_time_identifier_feature_name ,
1059+ self ._event_time_identifier_feature_type ,
1060+ ),
1061+ None ,
1062+ TableType .DATA_FRAME ,
1063+ )
1064+ )
1065+ if isinstance (self ._base , FeatureGroup ):
1066+ base_feature_group = construct_feature_group_to_be_merged (
1067+ self ._base , self ._included_feature_names
1068+ )
1069+ self ._record_identifier_feature_name = base_feature_group .record_identifier_feature_name
1070+ self ._event_time_identifier_feature_name = (
1071+ base_feature_group .event_time_identifier_feature .feature_name
1072+ )
1073+ self ._event_time_identifier_feature_type = (
1074+ base_feature_group .event_time_identifier_feature .feature_type
1075+ )
1076+ catalog = base_feature_group .catalog
1077+ database = base_feature_group .database
1078+ query_string = self ._construct_query_string (base_feature_group )
1079+ return query_string , catalog , database
1080+
10611081 def _run_query (self , query_string : str , catalog : str , database : str ) -> Dict [str , Any ]:
10621082 """Internal method for execute Athena query, wait for query finish and get query result.
10631083
0 commit comments