Skip to content

Commit 43b49ac

Browse files
LEXIO-38099: union operator (#10)
* LEXIO-38099: added UnionStatement
* LEXIO-38099: added union tests
* Version bumped to 0.10.0
* LEXIO-38099: cruft update

Co-authored-by: ns-circle-ci <[email protected]>
1 parent 1a16ee2 commit 43b49ac

File tree

5 files changed

+101
-7
lines changed

5 files changed

+101
-7
lines changed

.cruft.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"template": "https://github.com/NarrativeScience/cookiecutter-python-lib",
3-
"commit": "e090078902c7effbcde4bb9da97dd41eefee28b7",
3+
"commit": "06d791b4e3ac2362c595a9bcf0617f84e546ec3c",
44
"checkout": null,
55
"context": {
66
"cookiecutter": {

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pysaql"
3-
version = "0.9.0"
3+
version = "0.10.0"
44
description = "Python SAQL query builder"
55
authors = ["Jonathan Drake <[email protected]>"]
66
license = "BSD-3-Clause"

pysaql/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Python SAQL query builder"""
22

3-
__version__ = "0.9.0"
3+
__version__ = "0.10.0"

pysaql/stream.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,14 @@ def increment_id(self, incr: int) -> int:
6363
statement.stream._id += incr + i
6464
max_id = max(max_id, statement.stream._id)
6565
i += 1
66-
elif isinstance(statement, CogroupStatement):
67-
# For cogroup statements, leave the left-most (first) branch alone
68-
for (stream, _) in statement.streams[1:]:
66+
elif isinstance(statement, (CogroupStatement, UnionStatement)):
67+
# For cogroup and union statements, leave the left-most (first) branch alone
68+
if isinstance(statement, CogroupStatement):
69+
streams = [stream for (stream, _) in statement.streams[1:]]
70+
else:
71+
streams = list(statement.streams[1:])
72+
73+
for stream in streams:
6974
stream.increment_id(incr + i)
7075
max_id = max(max_id, stream._id)
7176
i += 1
@@ -398,6 +403,40 @@ def __str__(self) -> str:
398403
return "\n".join(lines)
399404

400405

406+
class UnionStatement(StreamStatement):
    """Statement that merges (unions) two or more identically-structured streams"""

    def __init__(
        self,
        stream: Stream,
        streams: Sequence[Stream],
    ) -> None:
        """Initializer

        Args:
            stream: Stream containing this statement
            streams: Streams that will be combined

        Raises:
            ValueError: if fewer than two streams are supplied

        """
        super().__init__()
        self.stream = stream
        # A union only makes sense with at least two inputs; the truthiness
        # check also rejects a missing (None) or empty sequence.
        has_enough_streams = streams and len(streams) >= 2
        if not has_enough_streams:
            raise ValueError("At least two streams are required")
        self.streams = streams

    def __str__(self) -> str:
        """Cast this union statement to a string"""
        # Render each input stream and capture its reference in the same pass
        # so that the rendering and the reference always stay paired.
        rendered = [(str(member), member.ref) for member in self.streams]

        output = [text for text, _ in rendered]
        member_refs = [ref for _, ref in rendered]

        # The union line itself comes last, after all of its inputs.
        output.append(f"{self.stream.ref} = union {', '.join(member_refs)};")
        return "\n".join(output)
438+
439+
401440
class FillStatement(StreamStatement):
402441
"""Statement to fill a data stream with missing dates"""
403442

@@ -471,3 +510,24 @@ def cogroup(
471510
# We'll use the ID of the first stream as the basis for incrementing.
472511
stream.increment_id(streams[0][0]._id)
473512
return stream
513+
514+
515+
def union(*streams: Stream) -> Stream:
    """Combine two or more data streams into one data stream

    Every stream should share the same field names and structure. The streams
    are not required to come from the same dataset.

    Args:
        streams: Streams that will be unioned together

    Returns:
        a new stream

    """
    combined = Stream()
    statement = UnionStatement(combined, streams)
    combined.add_statement(statement)
    # Every stream held by the union statement needs its ID bumped; the ID of
    # the first stream serves as the base for that increment.
    base_id = streams[0]._id
    combined.increment_id(base_id)
    return combined

tests/unit/test_stream.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from pysaql.enums import FillDateTypeString, JoinType, Order
66
from pysaql.scalar import field
7-
from pysaql.stream import cogroup, load, Stream
7+
from pysaql.stream import cogroup, load, Stream, union
88

99

1010
def test_load():
@@ -200,3 +200,37 @@ def test_fill__partition():
200200
str(stream)
201201
== """q0 = fill q0 by (dateCols=('Year', 'Month', "Y-M"), partition='Type');"""
202202
)
203+
204+
205+
def test_union():
    """Nested unions should render every input followed by its union line"""

    datasets = ("q0_dataset", "q1_dataset", "q2_dataset", "q3_dataset")
    q0, q1, q2, q3 = map(load, datasets)

    # Union two loads first, then union that result with the remaining two.
    inner = union(q0, q1)
    outer = union(inner, q2, q3)

    expected = [
        'q0 = load "q0_dataset";',
        'q1 = load "q1_dataset";',
        "q2 = union q0, q1;",
        'q3 = load "q2_dataset";',
        'q4 = load "q3_dataset";',
        "q5 = union q2, q3, q4;",
    ]
    rendered = str(outer).split("\n")
    assert rendered == expected
224+
225+
226+
def test_union__no_streams():
    """Calling union with zero streams should raise ValueError"""
    no_streams = ()
    with pytest.raises(ValueError):
        union(*no_streams)
230+
231+
232+
def test_union__one_streams():
    """Should raise ValueError when only a single stream is provided"""
    # Build the fixture outside the raises block so that only the union()
    # call itself can satisfy the expected ValueError; an error raised during
    # setup must fail the test rather than make it pass.
    q0 = load("q0_dataset")
    with pytest.raises(ValueError):
        union(q0)

0 commit comments

Comments
 (0)