99
1010from .enums import FillDateTypeString , JoinType , Order
1111from .scalar import BinaryOperation , field , Scalar
12- from .util import stringify , stringify_list
12+ from .util import flatten , stringify , stringify_list
1313
1414__ALL__ = ["load" , "cogroup" ]
1515
@@ -22,6 +22,15 @@ class StreamStatement(ABC):
2222
2323 stream : "Stream"
2424
25+ def get_streams (self ) -> list [Stream ]:
26+ """Get a flat list of streams nested within this stream statement
27+
28+ Returns:
29+ list of streams
30+
31+ """
32+ return []
33+
2534
2635class Stream :
2736 """Base class for a SAQL data stream"""
@@ -44,39 +53,14 @@ def ref(self) -> str:
4453 """Stream reference in the SAQL query"""
4554 return f"q{ self ._id } "
4655
47- def increment_id (self , incr : int ) -> int :
48- """Increment the stream ID
49-
50- This should not be called by clients.
51-
52- Args:
53- incr: Value to increment
56+ def get_streams (self ) -> list [Stream ]:
57+ """Get a flat list of streams nested within this stream
5458
5559 Returns:
56- new stream ID
60+ list of streams
5761
5862 """
59- max_id = 0
60- i = 0
61- for statement in self ._statements :
62- if isinstance (statement , LoadStatement ):
63- statement .stream ._id += incr + i
64- max_id = max (max_id , statement .stream ._id )
65- i += 1
66- elif isinstance (statement , (CogroupStatement , UnionStatement )):
67- # For cogroup and union statements, leave the left-most (first) branch alone
68- if isinstance (statement , CogroupStatement ):
69- streams = [stream for (stream , _ ) in statement .streams [1 :]]
70- else :
71- streams = list (statement .streams [1 :])
72-
73- for stream in streams :
74- stream .increment_id (incr + i )
75- max_id = max (max_id , stream ._id )
76- i += 1
77-
78- self ._id = max_id + 1
79- return self ._id
63+ return flatten ([s .get_streams () for s in self ._statements ])
8064
8165 def add_statement (self , statement : StreamStatement ) -> None :
8266 """Add a statement to the stream
@@ -86,6 +70,9 @@ def add_statement(self, statement: StreamStatement) -> None:
8670
8771 """
8872 self ._statements .append (statement )
73+ # Update all stream IDs
74+ for i , s in enumerate (flatten (statement .get_streams ())):
75+ s ._id = i
8976
9077 def field (self , name : str ) -> field :
9178 """Create a new field object scoped to this stream
@@ -215,6 +202,15 @@ def __str__(self) -> str:
215202 """Cast this load statement to a string"""
216203 return f'{ self .stream .ref } = load "{ self .name } ";'
217204
205+ def get_streams (self ) -> list [Stream ]:
206+ """Get a flat list of streams nested within this stream statement
207+
208+ Returns:
209+ list of streams
210+
211+ """
212+ return [self .stream ]
213+
218214
219215class ProjectionStatement (StreamStatement ):
220216 """Statement to project columns from a stream"""
@@ -402,6 +398,20 @@ def __str__(self) -> str:
402398
403399 return "\n " .join (lines )
404400
401+ def get_streams (self ) -> list [Stream ]:
402+ """Get a flat list of streams nested within this stream statement
403+
404+ Returns:
405+ list of streams
406+
407+ """
408+ return flatten (
409+ [
410+ [stream .get_streams () for (stream , _ ) in self .streams ],
411+ [self .stream ],
412+ ]
413+ )
414+
405415
406416class UnionStatement (StreamStatement ):
407417 """Statement to combine (union) two or more streams with the same structure into one"""
@@ -436,6 +446,15 @@ def __str__(self) -> str:
436446 lines .append (f"{ self .stream .ref } = union { ', ' .join (stream_refs )} ;" )
437447 return "\n " .join (lines )
438448
449+ def get_streams (self ) -> list [Stream ]:
450+ """Get a flat list of streams nested within this stream statement
451+
452+ Returns:
453+ list of streams
454+
455+ """
456+ return flatten ([[s .get_streams () for s in self .streams ], [self .stream ]])
457+
439458
440459class FillStatement (StreamStatement ):
441460 """Statement to fill a data stream with missing dates"""
@@ -506,9 +525,6 @@ def cogroup(
506525 """
507526 stream = Stream ()
508527 stream .add_statement (CogroupStatement (stream , streams , join_type ))
509- # Increment stream IDs for all streams contained in this cogroup statement.
510- # We'll use the ID of the first stream as the basis for incrementing.
511- stream .increment_id (streams [0 ][0 ]._id )
512528 return stream
513529
514530
@@ -527,7 +543,4 @@ def union(*streams: Stream) -> Stream:
527543 """
528544 stream = Stream ()
529545 stream .add_statement (UnionStatement (stream , streams ))
530- # Increment stream IDs for all streams contained in this union statement.
531- # We'll use the ID of the first stream as the basis for incrementing.
532- stream .increment_id (streams [0 ]._id )
533546 return stream
0 commit comments