
Commit cb8ba44

table schema string
1 parent 65ff85a commit cb8ba44

File tree

6 files changed (+173 -64 lines changed)


python/pyspark/pipelines/spark_connect_graph_element_registry.py

Lines changed: 24 additions & 12 deletions
@@ -30,6 +30,8 @@
 from pyspark.pipelines.flow import Flow
 from pyspark.pipelines.graph_element_registry import GraphElementRegistry
 from pyspark.pipelines.source_code_location import SourceCodeLocation
+from pyspark.sql.connect.types import pyspark_types_to_proto_types
+from pyspark.sql.types import StructType
 from typing import Any, cast
 import pyspark.sql.connect.proto as pb2
 
@@ -47,7 +49,7 @@ def register_dataset(self, dataset: Dataset) -> None:
         if isinstance(dataset, Table):
             table_properties = dataset.table_properties
             partition_cols = dataset.partition_cols
-            schema = None  # TODO
+            schema = dataset.schema
             format = dataset.format
 
         if isinstance(dataset, MaterializedView):
@@ -71,17 +73,27 @@ def register_dataset(self, dataset: Dataset) -> None:
                 messageParameters={"dataset_type": type(dataset).__name__},
             )
 
-        inner_command = pb2.PipelineCommand.DefineDataset(
-            dataflow_graph_id=self._dataflow_graph_id,
-            dataset_name=dataset.name,
-            dataset_type=dataset_type,
-            comment=dataset.comment,
-            table_properties=table_properties,
-            partition_cols=partition_cols,
-            schema=schema,
-            format=format,
-            source_code_location=source_code_location_to_proto(dataset.source_code_location),
-        )
+        define_dataset_kwargs = {
+            "dataflow_graph_id": self._dataflow_graph_id,
+            "dataset_name": dataset.name,
+            "dataset_type": dataset_type,
+            "comment": dataset.comment,
+            "table_properties": table_properties,
+            "partition_cols": partition_cols,
+            "format": format,
+            "source_code_location": source_code_location_to_proto(dataset.source_code_location),
+        }
+
+        if schema is not None:
+            if isinstance(schema, str):
+                define_dataset_kwargs["schema_string"] = schema
+            elif isinstance(schema, StructType):
+                define_dataset_kwargs["data_type"] = pyspark_types_to_proto_types(schema)
+            else:
+                # For other DataType objects
+                define_dataset_kwargs["data_type"] = pyspark_types_to_proto_types(schema)
+
+        inner_command = pb2.PipelineCommand.DefineDataset(**define_dataset_kwargs)
         command = pb2.Command()
         command.pipeline_command.define_dataset.CopyFrom(inner_command)
         self._client.execute_command(command)
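
For context, here is a minimal sketch (not part of the commit) of the two schema paths the registry now distinguishes: a DDL string is forwarded verbatim in the proto's schema_string field, while a StructType, or any other DataType, is converted with pyspark_types_to_proto_types and sent as data_type. It assumes a Spark Connect client environment where pyspark.sql.connect.types is importable.

    # Minimal sketch, assuming a Spark Connect client install.
    from pyspark.sql.types import StructType, StructField, StringType, IntegerType
    from pyspark.sql.connect.types import pyspark_types_to_proto_types

    # Path 1: a DDL schema string would be passed through unchanged and land
    # in the DefineDataset proto's schema_string field.
    schema_as_string = "id INT, name STRING"

    # Path 2: a StructType is converted to the Spark Connect proto DataType
    # message and would populate the data_type field instead.
    schema_as_struct = StructType([
        StructField("id", IntegerType()),
        StructField("name", StringType()),
    ])
    proto_data_type = pyspark_types_to_proto_types(schema_as_struct)
    print(proto_data_type)  # proto DataType describing struct<id:int, name:string>

Note that the elif and else branches in the new code call the same converter; the split only documents that non-StructType DataType values take the same data_type path.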
