1- from datetime import date
21import itertools
2+ from datetime import date
3+ from typing import Any , Callable , Iterable , Union
4+
35import substrait .gen .proto .algebra_pb2 as stalg
4- import substrait .gen .proto .type_pb2 as stp
56import substrait .gen .proto .extended_expression_pb2 as stee
67import substrait .gen .proto .extensions .extensions_pb2 as ste
8+ import substrait .gen .proto .type_pb2 as stp
79from substrait .extension_registry import ExtensionRegistry
10+ from substrait .type_inference import infer_extended_expression_schema
811from substrait .utils import (
9- type_num_names ,
10- merge_extension_urns ,
11- merge_extension_uris ,
1212 merge_extension_declarations ,
13+ merge_extension_uris ,
14+ merge_extension_urns ,
15+ type_num_names ,
1316)
14- from substrait .type_inference import infer_extended_expression_schema
15- from typing import Callable , Any , Union , Iterable
1617
1718UnboundExtendedExpression = Callable [
1819 [stp .NamedStruct , ExtensionRegistry ], stee .ExtendedExpression
2122
2223
2324def _alias_or_inferred (
24- alias : Union [Iterable [str ], str ,None ],
25+ alias : Union [Iterable [str ], str , None ],
2526 op : str ,
2627 args : Iterable [str ],
2728):
@@ -44,7 +45,7 @@ def resolve_expression(
4445
4546
4647def literal (
47- value : Any , type : stp .Type , alias : Union [Iterable [str ], str ,None ] = None
48+ value : Any , type : stp .Type , alias : Union [Iterable [str ], str , None ] = None
4849) -> UnboundExtendedExpression :
4950 """Builds a resolver for ExtendedExpression containing a literal expression"""
5051
@@ -154,7 +155,7 @@ def resolve(
154155 return resolve
155156
156157
157- def column (field : Union [str , int ], alias : Union [Iterable [str ], str ,None ] = None ):
158+ def column (field : Union [str , int ], alias : Union [Iterable [str ], str , None ] = None ):
158159 """Builds a resolver for ExtendedExpression containing a FieldReference expression
159160
160161 Accepts either an index or a field name of a desired field.
@@ -208,7 +209,7 @@ def scalar_function(
208209 urn : str ,
209210 function : str ,
210211 expressions : Iterable [ExtendedExpressionOrUnbound ],
211- alias : Union [Iterable [str ], str ,None ] = None ,
212+ alias : Union [Iterable [str ], str , None ] = None ,
212213):
213214 """Builds a resolver for ExtendedExpression containing a ScalarFunction expression"""
214215
@@ -306,7 +307,7 @@ def aggregate_function(
306307 urn : str ,
307308 function : str ,
308309 expressions : Iterable [ExtendedExpressionOrUnbound ],
309- alias : Union [Iterable [str ], str ,None ] = None ,
310+ alias : Union [Iterable [str ], str , None ] = None ,
310311):
311312 """Builds a resolver for ExtendedExpression containing a AggregateFunction measure"""
312313
@@ -402,7 +403,7 @@ def window_function(
402403 function : str ,
403404 expressions : Iterable [ExtendedExpressionOrUnbound ],
404405 partitions : Iterable [ExtendedExpressionOrUnbound ] = [],
405- alias : Union [Iterable [str ], str ,None ] = None ,
406+ alias : Union [Iterable [str ], str , None ] = None ,
406407):
407408 """Builds a resolver for ExtendedExpression containing a WindowFunction expression"""
408409
@@ -512,7 +513,7 @@ def resolve(
512513def if_then (
513514 ifs : Iterable [tuple [ExtendedExpressionOrUnbound , ExtendedExpressionOrUnbound ]],
514515 _else : ExtendedExpressionOrUnbound ,
515- alias : Union [Iterable [str ], str ,None ] = None ,
516+ alias : Union [Iterable [str ], str , None ] = None ,
516517):
517518 """Builds a resolver for ExtendedExpression containing an IfThen expression"""
518519
@@ -551,24 +552,16 @@ def resolve(
551552 referred_expr = [
552553 stee .ExpressionReference (
553554 expression = stalg .Expression (
554- if_then = stalg .Expression .IfThen (
555- ** {
556- "ifs" : [
557- stalg .Expression .IfThen .IfClause (
558- ** {
559- "if" : if_clause [0 ]
560- .referred_expr [0 ]
561- .expression ,
562- "then" : if_clause [1 ]
563- .referred_expr [0 ]
564- .expression ,
565- }
566- )
567- for if_clause in bound_ifs
568- ],
569- "else" : bound_else .referred_expr [0 ].expression ,
570- }
571- )
555+ if_then = stalg .Expression .IfThen (** {
556+ "ifs" : [
557+ stalg .Expression .IfThen .IfClause (** {
558+ "if" : if_clause [0 ].referred_expr [0 ].expression ,
559+ "then" : if_clause [1 ].referred_expr [0 ].expression ,
560+ })
561+ for if_clause in bound_ifs
562+ ],
563+ "else" : bound_else .referred_expr [0 ].expression ,
564+ })
572565 ),
573566 output_names = _alias_or_inferred (
574567 alias ,
@@ -639,12 +632,10 @@ def resolve(
639632 switch_expression = stalg .Expression .SwitchExpression (
640633 match = bound_match .referred_expr [0 ].expression ,
641634 ifs = [
642- stalg .Expression .SwitchExpression .IfValue (
643- ** {
644- "if" : i .referred_expr [0 ].expression .literal ,
645- "then" : t .referred_expr [0 ].expression ,
646- }
647- )
635+ stalg .Expression .SwitchExpression .IfValue (** {
636+ "if" : i .referred_expr [0 ].expression .literal ,
637+ "then" : t .referred_expr [0 ].expression ,
638+ })
648639 for i , t in bound_ifs
649640 ],
650641 ** {"else" : bound_else .referred_expr [0 ].expression },
@@ -767,7 +758,11 @@ def resolve(
767758 return resolve
768759
769760
770- def cast (input : ExtendedExpressionOrUnbound , type : stp .Type ,alias : Union [Iterable [str ], str ,None ] = None ):
761+ def cast (
762+ input : ExtendedExpressionOrUnbound ,
763+ type : stp .Type ,
764+ alias : Union [Iterable [str ], str , None ] = None ,
765+ ):
771766 """Builds a resolver for ExtendedExpression containing a cast expression"""
772767
773768 def resolve (
@@ -785,8 +780,9 @@ def resolve(
785780 failure_behavior = stalg .Expression .Cast .FAILURE_BEHAVIOR_RETURN_NULL ,
786781 )
787782 ),
788- output_names = _alias_or_inferred (alias , "cast" , []),
789-
783+ output_names = _alias_or_inferred (
784+ alias , "cast" , [bound_input .referred_expr [0 ].output_names [0 ]]
785+ ),
790786 )
791787 ],
792788 base_schema = base_schema ,
0 commit comments