Skip to content

Commit 99d11c6

Browse files
committed
feat(write_table): added the write table builder
1 parent 7e3aea1 commit 99d11c6

File tree

2 files changed

+47
-59
lines changed

2 files changed

+47
-59
lines changed

src/substrait/builders/plan.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,16 @@
77

88
from typing import Iterable, Optional, Union, Callable
99

10-
import substrait.gen.proto.algebra_pb2 as stalg
1110
from substrait.gen.proto.extensions.extensions_pb2 import AdvancedExtension
11+
import substrait.gen.proto.algebra_pb2 as stalg
12+
import substrait.gen.proto.extended_expression_pb2 as stee
1213
import substrait.gen.proto.plan_pb2 as stp
1314
import substrait.gen.proto.type_pb2 as stt
14-
import substrait.gen.proto.extended_expression_pb2 as stee
15-
from substrait.extension_registry import ExtensionRegistry
1615
from substrait.builders.extended_expression import (
1716
ExtendedExpressionOrUnbound,
1817
resolve_expression,
1918
)
19+
from substrait.extension_registry import ExtensionRegistry
2020
from substrait.type_inference import infer_plan_schema
2121
from substrait.utils import (
2222
merge_extension_declarations,
@@ -379,3 +379,33 @@ def resolve(registry: ExtensionRegistry) -> stp.Plan:
379379
)
380380

381381
return resolve
382+
383+
384+
def write_table(
385+
table_names: Union[str, Iterable[str]],
386+
input: PlanOrUnbound,
387+
create_mode: Union[stalg.WriteRel.CreateMode.ValueType, None] = None,
388+
) -> UnboundPlan:
389+
def resolve(registry: ExtensionRegistry) -> stp.Plan:
390+
bound_input = input if isinstance(input, stp.Plan) else input(registry)
391+
ns = infer_plan_schema(bound_input)
392+
_table_names = [table_names] if isinstance(table_names, str) else table_names
393+
_create_mode = create_mode or stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS
394+
395+
write_rel = stalg.Rel(
396+
write=stalg.WriteRel(
397+
input=bound_input.relations[-1].root.input,
398+
table_schema=ns,
399+
op=stalg.WriteRel.WRITE_OP_CTAS,
400+
create_mode=_create_mode,
401+
named_table=stalg.NamedObjectWrite(names=_table_names),
402+
)
403+
)
404+
return stp.Plan(
405+
relations=[
406+
stp.PlanRel(root=stalg.RelRoot(input=write_rel, names=ns.names))
407+
],
408+
**_merge_extensions(bound_input),
409+
)
410+
411+
return resolve

tests/builders/plan/test_write.py

Lines changed: 14 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,41 @@
1-
import substrait.gen.proto.type_pb2 as stt
2-
import substrait.gen.proto.plan_pb2 as stp
31
import substrait.gen.proto.algebra_pb2 as stalg
2+
import substrait.gen.proto.plan_pb2 as stp
3+
import substrait.gen.proto.type_pb2 as stt
4+
from substrait.builders.plan import read_named_table, write_table
45
from substrait.builders.type import boolean, i64
5-
from substrait.builders.plan import read_named_table
66

77
struct = stt.Type.Struct(types=[i64(nullable=False), boolean()])
88

99
named_struct = stt.NamedStruct(names=["id", "is_applicable"], struct=struct)
1010

1111

1212
def test_write_rel():
13-
actual = read_named_table("example_table", named_struct)(None)
13+
actual = write_table(
14+
"example_table_write_test",
15+
read_named_table("example_table", named_struct),
16+
)(None)
1417

15-
# write example table test
16-
stp.Plan(
18+
expected = stp.Plan(
1719
relations=[
1820
stp.PlanRel(
1921
root=stalg.RelRoot(
2022
input=stalg.Rel(
2123
write=stalg.WriteRel(
2224
input=stalg.Rel(
2325
read=stalg.ReadRel(
24-
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
26+
common=stalg.RelCommon(
27+
direct=stalg.RelCommon.Direct()
28+
),
2529
base_schema=named_struct,
2630
named_table=stalg.ReadRel.NamedTable(
2731
names=["example_table"]
2832
),
2933
)
3034
),
31-
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
35+
op=stalg.WriteRel.WRITE_OP_CTAS,
3236
table_schema=named_struct,
33-
create_mode=stalg.WriteRel.CreateMode.CREATE_MODE_REPLACE_IF_EXISTS,
34-
named_table=stalg.NamedTable(
35-
names=["example_table_write_test"]
36-
),
37-
)
38-
),
39-
names=["id", "is_applicable"],
40-
)
41-
)
42-
]
43-
)
44-
45-
# read back the table
46-
expected = stp.Plan(
47-
relations=[
48-
stp.PlanRel(
49-
root=stalg.RelRoot(
50-
input=stalg.Rel(
51-
read=stalg.ReadRel(
52-
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
53-
base_schema=named_struct,
54-
named_table=stalg.ReadRel.NamedTable(
37+
create_mode=stalg.WriteRel.CREATE_MODE_ERROR_IF_EXISTS,
38+
named_table=stalg.NamedObjectWrite(
5539
names=["example_table_write_test"]
5640
),
5741
)
@@ -61,30 +45,4 @@ def test_write_rel():
6145
)
6246
]
6347
)
64-
65-
assert actual == expected
66-
67-
68-
def test_write_rel_db():
69-
actual = read_named_table(["example_db", "example_table"], named_struct)(None)
70-
71-
expected = stp.Plan(
72-
relations=[
73-
stp.PlanRel(
74-
root=stalg.RelRoot(
75-
input=stalg.Rel(
76-
read=stalg.ReadRel(
77-
common=stalg.RelCommon(direct=stalg.RelCommon.Direct()),
78-
base_schema=named_struct,
79-
named_table=stalg.ReadRel.NamedTable(
80-
names=["example_db", "example_table"]
81-
),
82-
)
83-
),
84-
names=["id", "is_applicable"],
85-
)
86-
)
87-
]
88-
)
89-
9048
assert actual == expected

0 commit comments

Comments
 (0)