Skip to content

Commit e7ad07e

Browse files
jdrakens-circle-ci
andauthored
LEXIO-38100 Refactor ID setting algorithm (#11)
* refactor id algo * Version bumped to 0.11.0 Co-authored-by: ns-circle-ci <devops-team+circleci@narrativescience.com>
1 parent 43b49ac commit e7ad07e

File tree

5 files changed

+78
-38
lines changed

5 files changed

+78
-38
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pysaql"
3-
version = "0.10.0"
3+
version = "0.11.0"
44
description = "Python SAQL query builder"
55
authors = ["Jonathan Drake <jon.drake@salesforce.com>"]
66
license = "BSD-3-Clause"

pysaql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Python SAQL query builder"""
22

3-
__version__ = "0.10.0"
3+
__version__ = "0.11.0"

pysaql/stream.py

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .enums import FillDateTypeString, JoinType, Order
1111
from .scalar import BinaryOperation, field, Scalar
12-
from .util import stringify, stringify_list
12+
from .util import flatten, stringify, stringify_list
1313

1414
__ALL__ = ["load", "cogroup"]
1515

@@ -22,6 +22,15 @@ class StreamStatement(ABC):
2222

2323
stream: "Stream"
2424

25+
def get_streams(self) -> list[Stream]:
26+
"""Get a flat list of streams nested within this stream statement
27+
28+
Returns:
29+
list of streams
30+
31+
"""
32+
return []
33+
2534

2635
class Stream:
2736
"""Base class for a SAQL data stream"""
@@ -44,39 +53,14 @@ def ref(self) -> str:
4453
"""Stream reference in the SAQL query"""
4554
return f"q{self._id}"
4655

47-
def increment_id(self, incr: int) -> int:
48-
"""Increment the stream ID
49-
50-
This should not be called by clients.
51-
52-
Args:
53-
incr: Value to increment
56+
def get_streams(self) -> list[Stream]:
57+
"""Get a flat list of streams nested within this stream
5458
5559
Returns:
56-
new stream ID
60+
list of streams
5761
5862
"""
59-
max_id = 0
60-
i = 0
61-
for statement in self._statements:
62-
if isinstance(statement, LoadStatement):
63-
statement.stream._id += incr + i
64-
max_id = max(max_id, statement.stream._id)
65-
i += 1
66-
elif isinstance(statement, (CogroupStatement, UnionStatement)):
67-
# For cogroup and union statements, leave the left-most (first) branch alone
68-
if isinstance(statement, CogroupStatement):
69-
streams = [stream for (stream, _) in statement.streams[1:]]
70-
else:
71-
streams = list(statement.streams[1:])
72-
73-
for stream in streams:
74-
stream.increment_id(incr + i)
75-
max_id = max(max_id, stream._id)
76-
i += 1
77-
78-
self._id = max_id + 1
79-
return self._id
63+
return flatten([s.get_streams() for s in self._statements])
8064

8165
def add_statement(self, statement: StreamStatement) -> None:
8266
"""Add a statement to the stream
@@ -86,6 +70,9 @@ def add_statement(self, statement: StreamStatement) -> None:
8670
8771
"""
8872
self._statements.append(statement)
73+
# Update all stream IDs
74+
for i, s in enumerate(flatten(statement.get_streams())):
75+
s._id = i
8976

9077
def field(self, name: str) -> field:
9178
"""Create a new field object scoped to this stream
@@ -215,6 +202,15 @@ def __str__(self) -> str:
215202
"""Cast this load statement to a string"""
216203
return f'{self.stream.ref} = load "{self.name}";'
217204

205+
def get_streams(self) -> list[Stream]:
206+
"""Get a flat list of streams nested within this stream statement
207+
208+
Returns:
209+
list of streams
210+
211+
"""
212+
return [self.stream]
213+
218214

219215
class ProjectionStatement(StreamStatement):
220216
"""Statement to project columns from a stream"""
@@ -402,6 +398,20 @@ def __str__(self) -> str:
402398

403399
return "\n".join(lines)
404400

401+
def get_streams(self) -> list[Stream]:
402+
"""Get a flat list of streams nested within this stream statement
403+
404+
Returns:
405+
list of streams
406+
407+
"""
408+
return flatten(
409+
[
410+
[stream.get_streams() for (stream, _) in self.streams],
411+
[self.stream],
412+
]
413+
)
414+
405415

406416
class UnionStatement(StreamStatement):
407417
"""Statement to combine (union) two or more streams with the same structure into one"""
@@ -436,6 +446,15 @@ def __str__(self) -> str:
436446
lines.append(f"{self.stream.ref} = union {', '.join(stream_refs)};")
437447
return "\n".join(lines)
438448

449+
def get_streams(self) -> list[Stream]:
450+
"""Get a flat list of streams nested within this stream statement
451+
452+
Returns:
453+
list of streams
454+
455+
"""
456+
return flatten([[s.get_streams() for s in self.streams], [self.stream]])
457+
439458

440459
class FillStatement(StreamStatement):
441460
"""Statement to fill a data stream with missing dates"""
@@ -506,9 +525,6 @@ def cogroup(
506525
"""
507526
stream = Stream()
508527
stream.add_statement(CogroupStatement(stream, streams, join_type))
509-
# Increment stream IDs for all streams contained in this cogroup statement.
510-
# We'll use the ID of the first stream as the basis for incrementing.
511-
stream.increment_id(streams[0][0]._id)
512528
return stream
513529

514530

@@ -527,7 +543,4 @@ def union(*streams: Stream) -> Stream:
527543
"""
528544
stream = Stream()
529545
stream.add_statement(UnionStatement(stream, streams))
530-
# Increment stream IDs for all streams contained in this union statement.
531-
# We'll use the ID of the first stream as the basis for incrementing.
532-
stream.increment_id(streams[0]._id)
533546
return stream

pysaql/util.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,20 @@ def stringify_list(seq: Sequence) -> str:
6565
"""
6666
seq = [seq] if not isinstance(seq, (list, tuple, set)) else seq
6767
return f"({', '.join(str(s) for s in seq)})" if len(seq) > 1 else str(seq[0])
68+
69+
70+
def flatten(seq: list) -> list:
71+
"""Recursively flatten a list
72+
73+
Args:
74+
seq: Sequence of items
75+
76+
Returns:
77+
flatten list of items
78+
79+
"""
80+
if not seq:
81+
return seq
82+
if isinstance(seq[0], list):
83+
return flatten(seq[0]) + flatten(seq[1:])
84+
return seq[:1] + flatten(seq[1:])

tests/unit/test_util.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,13 @@ def test_stringify_list__one():
5454
def test_stringify_list__multiple():
5555
"""Should stringify a list with one item"""
5656
assert mod_ut.stringify_list(["foo", "bar"]) == "(foo, bar)"
57+
58+
59+
def test_flatten__empty():
60+
"""Should return empty list"""
61+
assert mod_ut.flatten([]) == []
62+
63+
64+
def test_flatten__nested():
65+
"""Should flatten nested list"""
66+
assert mod_ut.flatten([1, [2, [3, [4, 5]], 6], 7]) == [1, 2, 3, 4, 5, 6, 7]

0 commit comments

Comments
 (0)