python/pyspark/pipelines/spark_connect_graph_element_registry.py (19 additions, 3 deletions)
@@ -30,6 +30,8 @@
 from pyspark.pipelines.flow import Flow
 from pyspark.pipelines.graph_element_registry import GraphElementRegistry
 from pyspark.pipelines.source_code_location import SourceCodeLocation
+from pyspark.sql.connect.types import pyspark_types_to_proto_types
+from pyspark.sql.types import StructType
 from typing import Any, cast
 import pyspark.sql.connect.proto as pb2
@@ -47,7 +49,17 @@ def register_dataset(self, dataset: Dataset) -> None:
         if isinstance(dataset, Table):
             table_properties = dataset.table_properties
             partition_cols = dataset.partition_cols
-            schema = None  # TODO
+
+            if isinstance(dataset.schema, str):
+                schema_string = dataset.schema
+                schema_data_type = None
+            elif isinstance(dataset.schema, StructType):
+                schema_string = None
+                schema_data_type = pyspark_types_to_proto_types(dataset.schema)
+            else:
+                schema_string = None
+                schema_data_type = None
+
             format = dataset.format
 
             if isinstance(dataset, MaterializedView):
@@ -62,7 +74,8 @@ def register_dataset(self, dataset: Dataset) -> None:
         elif isinstance(dataset, TemporaryView):
             table_properties = None
             partition_cols = None
-            schema = None
+            schema_string = None
+            schema_data_type = None
             format = None
             dataset_type = pb2.DatasetType.TEMPORARY_VIEW
         else:
@@ -78,9 +91,12 @@ def register_dataset(self, dataset: Dataset) -> None:
             comment=dataset.comment,
             table_properties=table_properties,
             partition_cols=partition_cols,
-            schema=schema,
             format=format,
             source_code_location=source_code_location_to_proto(dataset.source_code_location),
+            # Even though schema_string is not required, the generated Python code seems to
+            # erroneously think it is required.
+            schema_string=schema_string,  # type: ignore[arg-type]
+            schema_data_type=schema_data_type,
         )
         command = pb2.Command()
         command.pipeline_command.define_dataset.CopyFrom(inner_command)
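The new branch decides how a table's declared schema is carried over Spark Connect: a DDL string is sent as schema_string, a StructType is converted with pyspark_types_to_proto_types and sent as schema_data_type, and an absent schema leaves both fields unset. Below is a minimal standalone sketch of that logic; split_schema is a hypothetical helper name used here for illustration, while the patch itself inlines the branch in register_dataset:

    from typing import Optional, Tuple, Union

    import pyspark.sql.connect.proto as pb2
    from pyspark.sql.connect.types import pyspark_types_to_proto_types
    from pyspark.sql.types import StructType


    def split_schema(
        schema: Union[str, StructType, None],
    ) -> Tuple[Optional[str], Optional[pb2.DataType]]:
        # A DDL string (e.g. "id INT, name STRING") passes through verbatim
        # for the server to parse; a StructType is converted client-side to
        # its proto DataType; None leaves both proto fields unset.
        if isinstance(schema, str):
            return schema, None
        elif isinstance(schema, StructType):
            return None, pyspark_types_to_proto_types(schema)
        else:
            return None, None

For example, split_schema("id INT, name STRING") returns ("id INT, name STRING", None), while a StructType input returns (None, <proto DataType>); either way, at most one of the two schema fields on pb2.PipelineCommand.DefineDataset ends up populated.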