@@ -30,6 +30,8 @@
 from pyspark.pipelines.flow import Flow
 from pyspark.pipelines.graph_element_registry import GraphElementRegistry
 from pyspark.pipelines.source_code_location import SourceCodeLocation
+from pyspark.sql.connect.types import pyspark_types_to_proto_types
+from pyspark.sql.types import StructType
 from typing import Any, cast

 import pyspark.sql.connect.proto as pb2
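
The two new imports pull in the Spark Connect type converter and StructType so that a structured schema can be serialized into the DefineDataset proto. As a minimal sketch of what pyspark_types_to_proto_types produces (the example schema and the print are illustrative, not part of the patch):

from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

# Illustrative schema, not taken from the patch.
schema = StructType(
    [
        StructField("id", IntegerType()),
        StructField("name", StringType()),
    ]
)

# Returns a spark.connect.DataType protobuf message describing the struct.
print(pyspark_types_to_proto_types(schema))
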
@@ -47,7 +49,7 @@ def register_dataset(self, dataset: Dataset) -> None:
         if isinstance(dataset, Table):
             table_properties = dataset.table_properties
             partition_cols = dataset.partition_cols
-            schema = None  # TODO
+            schema = dataset.schema
             format = dataset.format

         if isinstance(dataset, MaterializedView):
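
Rather than the hard-coded None placeholder, the registry now forwards whatever schema the Table definition carries. Judging from the branching added in the next hunk, that value can be either a DDL string or a DataType such as StructType; a hedged sketch of the two shapes (values are illustrative, not from the patch):

from pyspark.sql.types import LongType, StructField, StructType

# A schema declared as a DDL string...
ddl_schema = "id BIGINT, name STRING"

# ...or as a structured DataType; either can end up in dataset.schema.
struct_schema = StructType([StructField("id", LongType())])
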
@@ -71,17 +73,27 @@ def register_dataset(self, dataset: Dataset) -> None:
                 messageParameters={"dataset_type": type(dataset).__name__},
             )

-        inner_command = pb2.PipelineCommand.DefineDataset(
-            dataflow_graph_id=self._dataflow_graph_id,
-            dataset_name=dataset.name,
-            dataset_type=dataset_type,
-            comment=dataset.comment,
-            table_properties=table_properties,
-            partition_cols=partition_cols,
-            schema=schema,
-            format=format,
-            source_code_location=source_code_location_to_proto(dataset.source_code_location),
-        )
+        define_dataset_kwargs = {
+            "dataflow_graph_id": self._dataflow_graph_id,
+            "dataset_name": dataset.name,
+            "dataset_type": dataset_type,
+            "comment": dataset.comment,
+            "table_properties": table_properties,
+            "partition_cols": partition_cols,
+            "format": format,
+            "source_code_location": source_code_location_to_proto(dataset.source_code_location),
+        }
+
+        if schema is not None:
+            if isinstance(schema, str):
+                define_dataset_kwargs["schema_string"] = schema
+            elif isinstance(schema, StructType):
+                define_dataset_kwargs["data_type"] = pyspark_types_to_proto_types(schema)
+            else:
+                # For other DataType objects
+                define_dataset_kwargs["data_type"] = pyspark_types_to_proto_types(schema)
+
+        inner_command = pb2.PipelineCommand.DefineDataset(**define_dataset_kwargs)
         command = pb2.Command()
         command.pipeline_command.define_dataset.CopyFrom(inner_command)
         self._client.execute_command(command)
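
Building the arguments as a dict lets the command set at most one of the proto's two schema fields and omit both when no schema was declared. Below is a standalone restatement of that dispatch; schema_to_kwargs is a hypothetical helper written for illustration, not part of the patch:

from typing import Any, Dict, Optional, Union

from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import DataType, StructType


def schema_to_kwargs(schema: Optional[Union[str, DataType]]) -> Dict[str, Any]:
    # Hypothetical helper mirroring the branching in the patch.
    kwargs: Dict[str, Any] = {}
    if schema is None:
        return kwargs  # neither schema_string nor data_type gets set
    if isinstance(schema, str):
        kwargs["schema_string"] = schema  # DDL text, e.g. "id BIGINT"
    else:
        # StructType and any other DataType take the same conversion path.
        kwargs["data_type"] = pyspark_types_to_proto_types(schema)
    return kwargs


print(schema_to_kwargs("id BIGINT, name STRING"))
print(schema_to_kwargs(StructType()))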