Skip to content

Commit 82054ac

Browse files
committed
fix: test for casting with infer_name
1 parent 003258e commit 82054ac

File tree

2 files changed

+41
-44
lines changed

2 files changed

+41
-44
lines changed

src/substrait/builders/extended_expression.py

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
from datetime import date
21
import itertools
2+
from datetime import date
3+
from typing import Any, Callable, Iterable, Union
4+
35
import substrait.gen.proto.algebra_pb2 as stalg
4-
import substrait.gen.proto.type_pb2 as stp
56
import substrait.gen.proto.extended_expression_pb2 as stee
67
import substrait.gen.proto.extensions.extensions_pb2 as ste
8+
import substrait.gen.proto.type_pb2 as stp
79
from substrait.extension_registry import ExtensionRegistry
10+
from substrait.type_inference import infer_extended_expression_schema
811
from substrait.utils import (
9-
type_num_names,
10-
merge_extension_urns,
11-
merge_extension_uris,
1212
merge_extension_declarations,
13+
merge_extension_uris,
14+
merge_extension_urns,
15+
type_num_names,
1316
)
14-
from substrait.type_inference import infer_extended_expression_schema
15-
from typing import Callable, Any, Union, Iterable
1617

1718
UnboundExtendedExpression = Callable[
1819
[stp.NamedStruct, ExtensionRegistry], stee.ExtendedExpression
@@ -21,7 +22,7 @@
2122

2223

2324
def _alias_or_inferred(
24-
alias: Union[Iterable[str], str,None],
25+
alias: Union[Iterable[str], str, None],
2526
op: str,
2627
args: Iterable[str],
2728
):
@@ -44,7 +45,7 @@ def resolve_expression(
4445

4546

4647
def literal(
47-
value: Any, type: stp.Type, alias: Union[Iterable[str], str,None] = None
48+
value: Any, type: stp.Type, alias: Union[Iterable[str], str, None] = None
4849
) -> UnboundExtendedExpression:
4950
"""Builds a resolver for ExtendedExpression containing a literal expression"""
5051

@@ -154,7 +155,7 @@ def resolve(
154155
return resolve
155156

156157

157-
def column(field: Union[str, int], alias: Union[Iterable[str], str,None] = None):
158+
def column(field: Union[str, int], alias: Union[Iterable[str], str, None] = None):
158159
"""Builds a resolver for ExtendedExpression containing a FieldReference expression
159160
160161
Accepts either an index or a field name of a desired field.
@@ -208,7 +209,7 @@ def scalar_function(
208209
urn: str,
209210
function: str,
210211
expressions: Iterable[ExtendedExpressionOrUnbound],
211-
alias: Union[Iterable[str], str,None] = None,
212+
alias: Union[Iterable[str], str, None] = None,
212213
):
213214
"""Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
214215

@@ -306,7 +307,7 @@ def aggregate_function(
306307
urn: str,
307308
function: str,
308309
expressions: Iterable[ExtendedExpressionOrUnbound],
309-
alias: Union[Iterable[str], str,None] = None,
310+
alias: Union[Iterable[str], str, None] = None,
310311
):
311312
"""Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
312313

@@ -402,7 +403,7 @@ def window_function(
402403
function: str,
403404
expressions: Iterable[ExtendedExpressionOrUnbound],
404405
partitions: Iterable[ExtendedExpressionOrUnbound] = [],
405-
alias: Union[Iterable[str], str,None] = None,
406+
alias: Union[Iterable[str], str, None] = None,
406407
):
407408
"""Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
408409

@@ -512,7 +513,7 @@ def resolve(
512513
def if_then(
513514
ifs: Iterable[tuple[ExtendedExpressionOrUnbound, ExtendedExpressionOrUnbound]],
514515
_else: ExtendedExpressionOrUnbound,
515-
alias: Union[Iterable[str], str,None] = None,
516+
alias: Union[Iterable[str], str, None] = None,
516517
):
517518
"""Builds a resolver for ExtendedExpression containing an IfThen expression"""
518519

@@ -551,24 +552,16 @@ def resolve(
551552
referred_expr=[
552553
stee.ExpressionReference(
553554
expression=stalg.Expression(
554-
if_then=stalg.Expression.IfThen(
555-
**{
556-
"ifs": [
557-
stalg.Expression.IfThen.IfClause(
558-
**{
559-
"if": if_clause[0]
560-
.referred_expr[0]
561-
.expression,
562-
"then": if_clause[1]
563-
.referred_expr[0]
564-
.expression,
565-
}
566-
)
567-
for if_clause in bound_ifs
568-
],
569-
"else": bound_else.referred_expr[0].expression,
570-
}
571-
)
555+
if_then=stalg.Expression.IfThen(**{
556+
"ifs": [
557+
stalg.Expression.IfThen.IfClause(**{
558+
"if": if_clause[0].referred_expr[0].expression,
559+
"then": if_clause[1].referred_expr[0].expression,
560+
})
561+
for if_clause in bound_ifs
562+
],
563+
"else": bound_else.referred_expr[0].expression,
564+
})
572565
),
573566
output_names=_alias_or_inferred(
574567
alias,
@@ -639,12 +632,10 @@ def resolve(
639632
switch_expression=stalg.Expression.SwitchExpression(
640633
match=bound_match.referred_expr[0].expression,
641634
ifs=[
642-
stalg.Expression.SwitchExpression.IfValue(
643-
**{
644-
"if": i.referred_expr[0].expression.literal,
645-
"then": t.referred_expr[0].expression,
646-
}
647-
)
635+
stalg.Expression.SwitchExpression.IfValue(**{
636+
"if": i.referred_expr[0].expression.literal,
637+
"then": t.referred_expr[0].expression,
638+
})
648639
for i, t in bound_ifs
649640
],
650641
**{"else": bound_else.referred_expr[0].expression},
@@ -767,7 +758,11 @@ def resolve(
767758
return resolve
768759

769760

770-
def cast(input: ExtendedExpressionOrUnbound, type: stp.Type,alias: Union[Iterable[str], str,None] = None):
761+
def cast(
762+
input: ExtendedExpressionOrUnbound,
763+
type: stp.Type,
764+
alias: Union[Iterable[str], str, None] = None,
765+
):
771766
"""Builds a resolver for ExtendedExpression containing a cast expression"""
772767

773768
def resolve(
@@ -785,8 +780,9 @@ def resolve(
785780
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
786781
)
787782
),
788-
output_names=_alias_or_inferred(alias, "cast", []),
789-
783+
output_names=_alias_or_inferred(
784+
alias, "cast", [bound_input.referred_expr[0].output_names[0]]
785+
),
790786
)
791787
],
792788
base_schema=base_schema,

tests/builders/extended_expression/test_cast.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import substrait.gen.proto.algebra_pb2 as stalg
2-
import substrait.gen.proto.type_pb2 as stt
32
import substrait.gen.proto.extended_expression_pb2 as stee
3+
import substrait.gen.proto.type_pb2 as stt
44
from substrait.builders.extended_expression import cast, literal
55
from substrait.builders.type import i8, i16
66
from substrait.extension_registry import ExtensionRegistry
@@ -37,7 +37,7 @@ def test_cast():
3737
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
3838
)
3939
),
40-
output_names=["cast"],
40+
output_names=["cast(Literal(3))"],
4141
)
4242
],
4343
base_schema=named_struct,
@@ -48,6 +48,7 @@ def test_cast():
4848

4949
def test_cast_with_extension():
5050
import yaml
51+
5152
import substrait.gen.proto.extensions.extensions_pb2 as ste
5253
from substrait.builders.extended_expression import scalar_function
5354

@@ -134,7 +135,7 @@ def test_cast_with_extension():
134135
failure_behavior=stalg.Expression.Cast.FAILURE_BEHAVIOR_RETURN_NULL,
135136
)
136137
),
137-
output_names=["cast"],
138+
output_names=["cast(add(Literal(1),Literal(2)))"],
138139
)
139140
],
140141
base_schema=named_struct,

0 commit comments

Comments
 (0)