diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index 8faf7eb9ef589..cfb0c70fd71dc 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -30,6 +30,8 @@ from pyspark.pipelines.flow import Flow from pyspark.pipelines.graph_element_registry import GraphElementRegistry from pyspark.pipelines.source_code_location import SourceCodeLocation +from pyspark.sql.connect.types import pyspark_types_to_proto_types +from pyspark.sql.types import StructType from typing import Any, cast import pyspark.sql.connect.proto as pb2 @@ -47,7 +49,17 @@ def register_dataset(self, dataset: Dataset) -> None: if isinstance(dataset, Table): table_properties = dataset.table_properties partition_cols = dataset.partition_cols - schema = None # TODO + + if isinstance(dataset.schema, str): + schema_string = dataset.schema + schema_data_type = None + elif isinstance(dataset.schema, StructType): + schema_string = None + schema_data_type = pyspark_types_to_proto_types(dataset.schema) + else: + schema_string = None + schema_data_type = None + format = dataset.format if isinstance(dataset, MaterializedView): @@ -62,7 +74,8 @@ def register_dataset(self, dataset: Dataset) -> None: elif isinstance(dataset, TemporaryView): table_properties = None partition_cols = None - schema = None + schema_string = None + schema_data_type = None format = None dataset_type = pb2.DatasetType.TEMPORARY_VIEW else: @@ -78,9 +91,12 @@ def register_dataset(self, dataset: Dataset) -> None: comment=dataset.comment, table_properties=table_properties, partition_cols=partition_cols, - schema=schema, format=format, source_code_location=source_code_location_to_proto(dataset.source_code_location), + # Even though schema_string is not required, the generated Python code seems to + # erroneously think it is required. + schema_string=schema_string, # type: ignore[arg-type] + schema_data_type=schema_data_type, ) command = pb2.Command() command.pipeline_command.define_dataset.CopyFrom(inner_command) diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py index 849d141f9c498..ba3d2f634a488 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.py +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py @@ -41,7 +41,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xc6\x1b\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x12\xa1\x01\n*get_query_function_execution_signal_stream\x18\x07 \x01(\x0b\x32\x44.spark.connect.PipelineCommand.GetQueryFunctionExecutionSignalStreamH\x00R%getQueryFunctionExecutionSignalStream\x12\x88\x01\n!define_flow_query_function_result\x18\x08 \x01(\x0b\x32<.spark.connect.PipelineCommand.DefineFlowQueryFunctionResultH\x00R\x1d\x64\x65\x66ineFlowQueryFunctionResult\x1a\xb4\x02\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xc4\x05\n\rDefineDataset\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12&\n\x0c\x64\x61taset_name\x18\x02 \x01(\tH\x01R\x0b\x64\x61tasetName\x88\x01\x01\x12\x42\n\x0c\x64\x61taset_type\x18\x03 \x01(\x0e\x32\x1a.spark.connect.DatasetTypeH\x02R\x0b\x64\x61tasetType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x03R\x07\x63omment\x88\x01\x01\x12l\n\x10table_properties\x18\x05 \x03(\x0b\x32\x41.spark.connect.PipelineCommand.DefineDataset.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x06 \x03(\tR\rpartitionCols\x12\x34\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x04R\x06schema\x88\x01\x01\x12\x1b\n\x06\x66ormat\x18\x08 \x01(\tH\x05R\x06\x66ormat\x88\x01\x01\x12X\n\x14source_code_location\x18\t \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x06R\x12sourceCodeLocation\x88\x01\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0f\n\r_dataset_nameB\x0f\n\r_dataset_typeB\n\n\x08_commentB\t\n\x07_schemaB\t\n\x07_formatB\x17\n\x15_source_code_location\x1a\x85\x05\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x01R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x02R\x11targetDatasetName\x88\x01\x01\x12\x38\n\x08relation\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x03R\x08relation\x88\x01\x01\x12Q\n\x08sql_conf\x18\x05 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12 \n\tclient_id\x18\x06 \x01(\tH\x04R\x08\x63lientId\x88\x01\x01\x12X\n\x14source_code_location\x18\x07 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a:\n\x08Response\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0b\n\t_relationB\x0c\n\n_client_idB\x17\n\x15_source_code_location\x1a\x97\x02\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelection\x12\x15\n\x03\x64ry\x18\x05 \x01(\x08H\x02R\x03\x64ry\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_allB\x06\n\x04_dry\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_text\x1a\x9e\x01\n%GetQueryFunctionExecutionSignalStream\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tclient_id\x18\x02 \x01(\tH\x01R\x08\x63lientId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_client_id\x1a\xdd\x01\n\x1d\x44\x65\x66ineFlowQueryFunctionResult\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x12/\n\x11\x64\x61taflow_graph_id\x18\x02 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x38\n\x08relation\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x02R\x08relation\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0b\n\t_relationB\x0e\n\x0c\x63ommand_type"\xf4\x05\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x12n\n\x15\x64\x65\x66ine_dataset_result\x18\x02 \x01(\x0b\x32\x38.spark.connect.PipelineCommandResult.DefineDatasetResultH\x00R\x13\x64\x65\x66ineDatasetResult\x12\x65\n\x12\x64\x65\x66ine_flow_result\x18\x03 \x01(\x0b\x32\x35.spark.connect.PipelineCommandResult.DefineFlowResultH\x00R\x10\x64\x65\x66ineFlowResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x86\x01\n\x13\x44\x65\x66ineDatasetResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifier\x1a\x83\x01\n\x10\x44\x65\x66ineFlowResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifierB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message"z\n\x12SourceCodeLocation\x12 \n\tfile_name\x18\x01 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12$\n\x0bline_number\x18\x02 \x01(\x05H\x01R\nlineNumber\x88\x01\x01\x42\x0c\n\n_file_nameB\x0e\n\x0c_line_number"E\n$PipelineQueryFunctionExecutionSignal\x12\x1d\n\nflow_names\x18\x01 \x03(\tR\tflowNames*a\n\x0b\x44\x61tasetType\x12\x1c\n\x18\x44\x41TASET_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xfb\x1b\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x12\xa1\x01\n*get_query_function_execution_signal_stream\x18\x07 \x01(\x0b\x32\x44.spark.connect.PipelineCommand.GetQueryFunctionExecutionSignalStreamH\x00R%getQueryFunctionExecutionSignalStream\x12\x88\x01\n!define_flow_query_function_result\x18\x08 \x01(\x0b\x32<.spark.connect.PipelineCommand.DefineFlowQueryFunctionResultH\x00R\x1d\x64\x65\x66ineFlowQueryFunctionResult\x1a\xb4\x02\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xf9\x05\n\rDefineDataset\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12&\n\x0c\x64\x61taset_name\x18\x02 \x01(\tH\x02R\x0b\x64\x61tasetName\x88\x01\x01\x12\x42\n\x0c\x64\x61taset_type\x18\x03 \x01(\x0e\x32\x1a.spark.connect.DatasetTypeH\x03R\x0b\x64\x61tasetType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x04R\x07\x63omment\x88\x01\x01\x12l\n\x10table_properties\x18\x05 \x03(\x0b\x32\x41.spark.connect.PipelineCommand.DefineDataset.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x06 \x03(\tR\rpartitionCols\x12\x43\n\x10schema_data_type\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x0eschemaDataType\x12%\n\rschema_string\x18\x08 \x01(\tH\x00R\x0cschemaString\x12\x1b\n\x06\x66ormat\x18\t \x01(\tH\x05R\x06\x66ormat\x88\x01\x01\x12X\n\x14source_code_location\x18\n \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x06R\x12sourceCodeLocation\x88\x01\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x08\n\x06schemaB\x14\n\x12_dataflow_graph_idB\x0f\n\r_dataset_nameB\x0f\n\r_dataset_typeB\n\n\x08_commentB\t\n\x07_formatB\x17\n\x15_source_code_location\x1a\x85\x05\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x01R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x02R\x11targetDatasetName\x88\x01\x01\x12\x38\n\x08relation\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x03R\x08relation\x88\x01\x01\x12Q\n\x08sql_conf\x18\x05 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12 \n\tclient_id\x18\x06 \x01(\tH\x04R\x08\x63lientId\x88\x01\x01\x12X\n\x14source_code_location\x18\x07 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a:\n\x08Response\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0b\n\t_relationB\x0c\n\n_client_idB\x17\n\x15_source_code_location\x1a\x97\x02\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelection\x12\x15\n\x03\x64ry\x18\x05 \x01(\x08H\x02R\x03\x64ry\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_allB\x06\n\x04_dry\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_text\x1a\x9e\x01\n%GetQueryFunctionExecutionSignalStream\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tclient_id\x18\x02 \x01(\tH\x01R\x08\x63lientId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_client_id\x1a\xdd\x01\n\x1d\x44\x65\x66ineFlowQueryFunctionResult\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x12/\n\x11\x64\x61taflow_graph_id\x18\x02 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x38\n\x08relation\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x02R\x08relation\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0b\n\t_relationB\x0e\n\x0c\x63ommand_type"\xf4\x05\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x12n\n\x15\x64\x65\x66ine_dataset_result\x18\x02 \x01(\x0b\x32\x38.spark.connect.PipelineCommandResult.DefineDatasetResultH\x00R\x13\x64\x65\x66ineDatasetResult\x12\x65\n\x12\x64\x65\x66ine_flow_result\x18\x03 \x01(\x0b\x32\x35.spark.connect.PipelineCommandResult.DefineFlowResultH\x00R\x10\x64\x65\x66ineFlowResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x86\x01\n\x13\x44\x65\x66ineDatasetResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifier\x1a\x83\x01\n\x10\x44\x65\x66ineFlowResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifierB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message"z\n\x12SourceCodeLocation\x12 \n\tfile_name\x18\x01 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12$\n\x0bline_number\x18\x02 \x01(\x05H\x01R\nlineNumber\x88\x01\x01\x42\x0c\n\n_file_nameB\x0e\n\x0c_line_number"E\n$PipelineQueryFunctionExecutionSignal\x12\x1d\n\nflow_names\x18\x01 \x03(\tR\tflowNames*a\n\x0b\x44\x61tasetType\x12\x1c\n\x18\x44\x41TASET_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -60,10 +60,10 @@ _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options = b"8\001" _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = b"8\001" - _globals["_DATASETTYPE"]._serialized_start = 4843 - _globals["_DATASETTYPE"]._serialized_end = 4940 + _globals["_DATASETTYPE"]._serialized_start = 4896 + _globals["_DATASETTYPE"]._serialized_end = 4993 _globals["_PIPELINECOMMAND"]._serialized_start = 168 - _globals["_PIPELINECOMMAND"]._serialized_end = 3694 + _globals["_PIPELINECOMMAND"]._serialized_end = 3747 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1050 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1358 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start = 1259 @@ -71,37 +71,37 @@ _globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_start = 1360 _globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_end = 1450 _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_start = 1453 - _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_end = 2161 - _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_start = 1980 - _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_end = 2046 - _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2164 - _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 2809 + _globals["_PIPELINECOMMAND_DEFINEDATASET"]._serialized_end = 2214 + _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_start = 2034 + _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_end = 2100 + _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2217 + _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 2862 _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start = 1259 _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 1317 - _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 2639 - _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 2697 - _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2812 - _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3091 - _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 3094 - _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 3293 - _globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start = 3296 - _globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end = 3454 - _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start = 3457 - _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end = 3678 - _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 3697 - _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 4453 - _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 4069 - _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 4167 - _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start = 4170 - _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end = 4304 - _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start = 4307 - _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 4438 - _globals["_PIPELINEEVENTRESULT"]._serialized_start = 4455 - _globals["_PIPELINEEVENTRESULT"]._serialized_end = 4528 - _globals["_PIPELINEEVENT"]._serialized_start = 4530 - _globals["_PIPELINEEVENT"]._serialized_end = 4646 - _globals["_SOURCECODELOCATION"]._serialized_start = 4648 - _globals["_SOURCECODELOCATION"]._serialized_end = 4770 - _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 4772 - _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 4841 + _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 2692 + _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 2750 + _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2865 + _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3144 + _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 3147 + _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 3346 + _globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start = 3349 + _globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end = 3507 + _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start = 3510 + _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end = 3731 + _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 3750 + _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 4506 + _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 4122 + _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 4220 + _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_start = 4223 + _globals["_PIPELINECOMMANDRESULT_DEFINEDATASETRESULT"]._serialized_end = 4357 + _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start = 4360 + _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 4491 + _globals["_PIPELINEEVENTRESULT"]._serialized_start = 4508 + _globals["_PIPELINEEVENTRESULT"]._serialized_end = 4581 + _globals["_PIPELINEEVENT"]._serialized_start = 4583 + _globals["_PIPELINEEVENT"]._serialized_end = 4699 + _globals["_SOURCECODELOCATION"]._serialized_start = 4701 + _globals["_SOURCECODELOCATION"]._serialized_end = 4823 + _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 4825 + _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 4894 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi index b5ed1c216a837..575e15e0cb585 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi @@ -231,7 +231,8 @@ class PipelineCommand(google.protobuf.message.Message): COMMENT_FIELD_NUMBER: builtins.int TABLE_PROPERTIES_FIELD_NUMBER: builtins.int PARTITION_COLS_FIELD_NUMBER: builtins.int - SCHEMA_FIELD_NUMBER: builtins.int + SCHEMA_DATA_TYPE_FIELD_NUMBER: builtins.int + SCHEMA_STRING_FIELD_NUMBER: builtins.int FORMAT_FIELD_NUMBER: builtins.int SOURCE_CODE_LOCATION_FIELD_NUMBER: builtins.int dataflow_graph_id: builtins.str @@ -255,8 +256,8 @@ class PipelineCommand(google.protobuf.message.Message): dataset_type == MATERIALIZED_VIEW. """ @property - def schema(self) -> pyspark.sql.connect.proto.types_pb2.DataType: - """Schema for the dataset. If unset, this will be inferred from incoming flows.""" + def schema_data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ... + schema_string: builtins.str format: builtins.str """The output table format of the dataset. Only applies to dataset_type == TABLE and dataset_type == MATERIALIZED_VIEW. @@ -273,7 +274,8 @@ class PipelineCommand(google.protobuf.message.Message): comment: builtins.str | None = ..., table_properties: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., partition_cols: collections.abc.Iterable[builtins.str] | None = ..., - schema: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + schema_data_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ..., + schema_string: builtins.str = ..., format: builtins.str | None = ..., source_code_location: global___SourceCodeLocation | None = ..., ) -> None: ... @@ -290,8 +292,6 @@ class PipelineCommand(google.protobuf.message.Message): b"_dataset_type", "_format", b"_format", - "_schema", - b"_schema", "_source_code_location", b"_source_code_location", "comment", @@ -306,6 +306,10 @@ class PipelineCommand(google.protobuf.message.Message): b"format", "schema", b"schema", + "schema_data_type", + b"schema_data_type", + "schema_string", + b"schema_string", "source_code_location", b"source_code_location", ], @@ -323,8 +327,6 @@ class PipelineCommand(google.protobuf.message.Message): b"_dataset_type", "_format", b"_format", - "_schema", - b"_schema", "_source_code_location", b"_source_code_location", "comment", @@ -341,6 +343,10 @@ class PipelineCommand(google.protobuf.message.Message): b"partition_cols", "schema", b"schema", + "schema_data_type", + b"schema_data_type", + "schema_string", + b"schema_string", "source_code_location", b"source_code_location", "table_properties", @@ -369,16 +375,16 @@ class PipelineCommand(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["_format", b"_format"] ) -> typing_extensions.Literal["format"] | None: ... @typing.overload - def WhichOneof( - self, oneof_group: typing_extensions.Literal["_schema", b"_schema"] - ) -> typing_extensions.Literal["schema"] | None: ... - @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal[ "_source_code_location", b"_source_code_location" ], ) -> typing_extensions.Literal["source_code_location"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["schema", b"schema"] + ) -> typing_extensions.Literal["schema_data_type", "schema_string"] | None: ... class DefineFlow(google.protobuf.message.Message): """Request to define a flow targeting a dataset.""" diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto index 16d211f9f72d3..130dcd54e59b4 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto @@ -81,14 +81,17 @@ message PipelineCommand { repeated string partition_cols = 6; // Schema for the dataset. If unset, this will be inferred from incoming flows. - optional spark.connect.DataType schema = 7; + oneof schema { + spark.connect.DataType schema_data_type = 7; + string schema_string = 8; + } // The output table format of the dataset. Only applies to dataset_type == TABLE and // dataset_type == MATERIALIZED_VIEW. - optional string format = 8; + optional string format = 9; // The location in source code that this dataset was defined. - optional SourceCodeLocation source_code_location = 9; + optional SourceCodeLocation source_code_location = 10; } // Request to define a flow targeting a dataset. diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 01402c64e8a2a..21f6c9492f68b 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -188,10 +188,17 @@ private[connect] object PipelinesHandler extends Logging { Table( identifier = qualifiedIdentifier, comment = Option(dataset.getComment), - specifiedSchema = Option.when(dataset.hasSchema)( - DataTypeProtoConverter - .toCatalystType(dataset.getSchema) - .asInstanceOf[StructType]), + specifiedSchema = dataset.getSchemaCase match { + case proto.PipelineCommand.DefineDataset.SchemaCase.SCHEMA_DATA_TYPE => + Some( + DataTypeProtoConverter + .toCatalystType(dataset.getSchemaDataType) + .asInstanceOf[StructType]) + case proto.PipelineCommand.DefineDataset.SchemaCase.SCHEMA_STRING => + Some(StructType.fromDDL(dataset.getSchemaString)) + case proto.PipelineCommand.DefineDataset.SchemaCase.SCHEMA_NOT_SET => + None + }, partitionCols = Option(dataset.getPartitionColsList.asScala.toSeq) .filter(_.nonEmpty), properties = dataset.getTablePropertiesMap.asScala.toMap, diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index b99615062d456..b5461c415b0de 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.pipelines.common.FlowStatus import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl, QueryOrigin, QueryOriginType} import org.apache.spark.sql.pipelines.logging.EventLevel import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} +import org.apache.spark.sql.types.StructType /** * Test suite that starts a Spark Connect server and executes Spark Declarative Pipelines Python @@ -694,6 +695,82 @@ class PythonPipelineSuite parameters = Map.empty) } + test("table with string schema") { + val graph = buildGraph(""" + |from pyspark.sql.functions import lit + | + |@dp.materialized_view(schema="id LONG, name STRING") + |def table_with_string_schema(): + | return spark.range(5).withColumn("name", lit("test")) + |""".stripMargin) + .resolve() + .validate() + + assert(graph.flows.size == 1) + assert(graph.tables.size == 1) + + val table = graph.table(graphIdentifier("table_with_string_schema")) + assert(table.specifiedSchema.isDefined) + assert(table.specifiedSchema.get == StructType.fromDDL("id LONG, name STRING")) + } + + test("table with StructType schema") { + val graph = buildGraph(""" + |from pyspark.sql.types import StructType, StructField, LongType, StringType + |from pyspark.sql.functions import lit + | + |@dp.materialized_view(schema=StructType([ + | StructField("id", LongType(), True), + | StructField("name", StringType(), True) + |])) + |def table_with_struct_schema(): + | return spark.range(5).withColumn("name", lit("test")) + |""".stripMargin) + .resolve() + .validate() + + assert(graph.flows.size == 1) + assert(graph.tables.size == 1) + + val table = graph.table(graphIdentifier("table_with_struct_schema")) + assert(table.specifiedSchema.isDefined) + assert(table.specifiedSchema.get == StructType.fromDDL("id LONG, name STRING")) + } + + test("string schema validation error - schema mismatch") { + val graph = buildGraph(""" + |from pyspark.sql.functions import lit + | + |@dp.materialized_view(schema="id LONG, name STRING") + |def table_with_wrong_schema(): + | return spark.range(5).withColumn("wrong_column", lit("test")) + |""".stripMargin) + .resolve() + + val ex = intercept[AnalysisException] { graph.validate() } + assert(ex.getMessage.contains("has a user-specified schema that is incompatible")) + assert(ex.getMessage.contains("table_with_wrong_schema")) + } + + test("StructType schema validation error - schema mismatch") { + val graph = buildGraph(""" + |from pyspark.sql.types import StructType, StructField, LongType, StringType + |from pyspark.sql.functions import lit + | + |@dp.materialized_view(schema=StructType([ + | StructField("id", LongType(), True), + | StructField("name", StringType(), True) + |])) + |def table_with_wrong_struct_schema(): + | return spark.range(5).withColumn("different_column", lit("test")) + |""".stripMargin) + .resolve() + + val ex = intercept[AnalysisException] { graph.validate() } + assert(ex.getMessage.contains("has a user-specified schema that is incompatible")) + assert(ex.getMessage.contains("table_with_wrong_struct_schema")) + } + /** * Executes Python code in a separate process and returns the exit code. *