11"""Contains stream expressions and statements"""
22
3+ from __future__ import annotations
4+
35from abc import ABC
46import functools
57import operator
68from typing import List , Optional , Sequence , Tuple , Union
79
8- from typing_extensions import Self
9-
1010from .enums import FillDateTypeString , JoinType , Order
1111from .expression import Expression
1212from .scalar import BinaryOperation , field , Scalar
@@ -83,7 +83,7 @@ def add_statement(self, statement: StreamStatement) -> None:
8383 """
8484 self ._statements .append (statement )
8585
86- def foreach (self , * fields : Scalar ) -> Self :
86+ def foreach (self , * fields : Scalar ) -> Stream :
8787 """Applies a set of expressions to every row in a dataset.
8888
8989 This action is often referred to as projection
@@ -98,14 +98,15 @@ def foreach(self, *fields: Scalar) -> Self:
9898 self ._statements .append (ProjectionStatement (self , fields ))
9999 return self
100100
101- def group (self , * fields : Scalar ) -> Self :
101+ def group (self , * fields : Scalar ) -> Stream :
102102 """Organizes the rows returned from a query into groups
103103
104104 Within each group, you can apply an aggregate function, such as count() or sum()
105105 to get the number of items or sum, respectively.
106106
107107 Args:
108- fields: One or more fields to group by
108+ fields: One or more fields to group by. If no fields are provided,
109+ "group by all" is assumed.
109110
110111 Returns:
111112 self
@@ -114,7 +115,7 @@ def group(self, *fields: Scalar) -> Self:
114115 self ._statements .append (GroupStatement (self , fields ))
115116 return self
116117
117- def filter (self , * filters : BinaryOperation ) -> Self :
118+ def filter (self , * filters : BinaryOperation ) -> Stream :
118119 """Selects rows from a dataset based on a filter predicate
119120
120121 Args:
@@ -128,7 +129,7 @@ def filter(self, *filters: BinaryOperation) -> Self:
128129 self ._statements .append (FilterStatement (self , filters ))
129130 return self
130131
131- def order (self , * fields : Union [Scalar , Tuple [Scalar , Order ]]) -> Self :
132+ def order (self , * fields : Union [Scalar , Tuple [Scalar , Order ]]) -> Stream :
132133 """Sorts in ascending or descending order on one or more fields.
133134
134135 Args:
@@ -141,7 +142,7 @@ def order(self, *fields: Union[Scalar, Tuple[Scalar, Order]]) -> Self:
141142 self ._statements .append (OrderStatement (self , fields ))
142143 return self
143144
144- def limit (self , limit : int ) -> Self :
145+ def limit (self , limit : int ) -> Stream :
145146 """Limits the number of rows returned.
146147
147148 Args:
@@ -159,7 +160,7 @@ def fill(
159160 date_cols : Sequence [field ],
160161 date_type_string : FillDateTypeString ,
161162 partition : Optional [field ] = None ,
162- ) -> Self :
163+ ) -> Stream :
163164 """Fills missing date values by adding rows in data stream
164165
165166 Args:
@@ -202,7 +203,7 @@ def __str__(self) -> str:
202203class ProjectionStatement (StreamStatement ):
203204 """Statement to project columns from a stream"""
204205
205- def __init__ (self , stream : Stream , fields : List [Scalar ]) -> None :
206+ def __init__ (self , stream : Stream , fields : Sequence [Scalar ]) -> None :
206207 """Initializer
207208
208209 Args:
@@ -228,7 +229,7 @@ class OrderStatement(StreamStatement):
228229 def __init__ (
229230 self ,
230231 stream : Stream ,
231- fields : Union [Scalar , List [ Scalar ], List [ Tuple [Scalar , Order ]]],
232+ fields : Sequence [ Union [Scalar , Tuple [Scalar , Order ]]],
232233 ) -> None :
233234 """Initializer
234235
@@ -284,30 +285,29 @@ def __str__(self) -> str:
284285class GroupStatement (StreamStatement ):
285286 """Statement to group rows in a stream"""
286287
287- def __init__ (self , stream : Stream , fields : List [Scalar ]):
288+ def __init__ (self , stream : Stream , fields : Sequence [Scalar ]):
288289 """Initializer
289290
290291 Args:
291292 stream: Stream containing this statement
292- fields: One or more fields to group by
293+ fields: One or more fields to group by. If no fields are provided,
294+ "group by all" is assumed.
293295
294296 """
295297 super ().__init__ ()
296298 self .stream = stream
297- if not fields :
298- raise ValueError ("At least one field is required" )
299299 self .fields = fields
300300
301301 def __str__ (self ) -> str :
302302 """Cast this group statement to a string"""
303- fields = ", " . join ( str ( f ) for f in self .fields )
303+ fields = stringify_list ( self . fields ) if self .fields else "all"
304304 return f"{ self .stream .ref } = group { self .stream .ref } by { fields } ;"
305305
306306
307307class FilterStatement (StreamStatement ):
308308 """Statement to filter rows in a stream"""
309309
310- def __init__ (self , stream : Stream , filters : List [BinaryOperation ]) -> None :
310+ def __init__ (self , stream : Stream , filters : Sequence [BinaryOperation ]) -> None :
311311 """Initializer
312312
313313 Args:
@@ -336,7 +336,7 @@ class CogroupStatement(StreamStatement):
336336 def __init__ (
337337 self ,
338338 stream : Stream ,
339- streams : List [Tuple [Stream , Scalar ]],
339+ streams : Sequence [Tuple [Stream , Scalar ]],
340340 join_type : JoinType = JoinType .inner ,
341341 ) -> None :
342342 """Initializer
0 commit comments