Commit 55e4a71

Refactor

1 parent a426ed3 commit 55e4a71

File tree

4 files changed: +56 -26 lines changed

  dbldatagen/config.py
  dbldatagen/data_generator.py
  dbldatagen/utils.py
  tests/test_output.py

dbldatagen/config.py

Lines changed: 12 additions & 3 deletions
@@ -5,7 +5,7 @@
 """
 This module implements configuration classes for writing generated data.
 """
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 
 
 @dataclass(frozen=True, slots=True)
@@ -23,5 +23,14 @@ class OutputDataset:
     location: str
     output_mode: str = "append"
     format: str = "delta"
-    options: dict[str, str] = field(default_factory=dict)
-    trigger: dict[str, bool | str] = field(default_factory=dict)
+    options: dict[str, str] | None = None
+    trigger: dict[str, str] | None = None
+
+    def __post_init__(self) -> None:
+        if not self.trigger:
+            return
+
+        # Only processingTime is currently supported
+        if "processingTime" not in self.trigger:
+            valid_trigger_format = '{"processingTime": "10 SECONDS"}'
+            raise ValueError(f"Attribute 'trigger' must be a dictionary of the form '{valid_trigger_format}'")
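
For context, a minimal sketch of how the reworked trigger validation behaves; the field values below are illustrative placeholders, not taken from the commit:

from dbldatagen.config import OutputDataset

# Accepted: "processingTime" is the only trigger key the __post_init__ check allows
ds = OutputDataset(location="/some/output/path", trigger={"processingTime": "10 SECONDS"})

# Rejected: any other trigger key now raises ValueError at construction time
try:
    OutputDataset(location="/some/output/path", trigger={"availableNow": "true"})
except ValueError as err:
    print(err)

# Omitting trigger (or passing None) skips the check entirely, leaving a batch-style config
batch_ds = OutputDataset(location="/some/output/path", format="parquet")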

dbldatagen/data_generator.py

Lines changed: 12 additions & 4 deletions
@@ -15,6 +15,7 @@
 from typing import Any
 
 from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.streaming.query import StreamingQuery
 from pyspark.sql.types import DataType, IntegerType, LongType, StringType, StructField, StructType
 
 from dbldatagen import datagen_constants
@@ -1917,17 +1918,24 @@ def scriptMerge(
         return result
 
     def buildOutputDataset(
-        self, output_dataset: OutputDataset, generator_options: dict[str, Any] | None = None
-    ) -> None:
+        self, output_dataset: OutputDataset,
+        with_streaming: bool | None = None,
+        generator_options: dict[str, Any] | None = None
+    ) -> StreamingQuery | None:
         """
         Builds a `DataFrame` from the `DataGenerator` and writes the data to a target table.
 
         :param output_dataset: Output configuration for writing generated data
+        :param with_streaming: Whether to generate data using streaming. If None, auto-detects based on trigger
         :param generator_options: Options for building the generator (e.g. `{"rowsPerSecond": 100}`)
+        :returns: A Spark `StreamingQuery` if data is written in streaming, otherwise `None`
         """
-        with_streaming = output_dataset.trigger is not None
+        # Auto-detect streaming mode if not explicitly specified
+        if with_streaming is None:
+            with_streaming = output_dataset.trigger is not None and len(output_dataset.trigger) > 0
+
         df = self.build(withStreaming=with_streaming, options=generator_options)
-        write_data_to_output(df, config=output_dataset)
+        return write_data_to_output(df, config=output_dataset)
 
     @staticmethod
     def loadFromJson(options: str) -> "DataGenerator":
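
A hedged usage sketch of the new signature, assuming gen is an already-configured DataGenerator and the paths are placeholders:

# Batch write: no trigger on the config, so streaming is auto-detected as False and None is returned
batch_config = OutputDataset(location="/some/batch/path", format="parquet", options={"mergeSchema": "true"})
result = gen.buildOutputDataset(batch_config)  # result is None

# Streaming write: a processingTime trigger plus a checkpoint location, with streaming forced on
stream_config = OutputDataset(
    location="/some/stream/path",
    options={"checkpointLocation": "/some/checkpoint/path"},
    trigger={"processingTime": "5 SECONDS"},
)
query = gen.buildOutputDataset(stream_config, with_streaming=True)
query.stop()  # the caller now owns the StreamingQuery lifecycle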

dbldatagen/utils.py

Lines changed: 7 additions & 2 deletions
@@ -19,6 +19,7 @@
 
 import jmespath
 from pyspark.sql import DataFrame
+from pyspark.sql.streaming.query import StreamingQuery
 
 from dbldatagen.config import OutputDataset
 
@@ -365,12 +366,13 @@ def system_time_millis() -> int:
     return curr_time
 
 
-def write_data_to_output(df: DataFrame, config: OutputDataset) -> None:
+def write_data_to_output(df: DataFrame, config: OutputDataset) -> StreamingQuery | None:
     """
     Writes a DataFrame to the sink configured in the output configuration.
 
     :param df: Spark DataFrame to write
     :param config: Output configuration passed as an `OutputConfig`
+    :returns: A Spark `StreamingQuery` if data is written in streaming, otherwise `None`
     """
     if df.isStreaming:
         if not config.trigger:
@@ -388,11 +390,14 @@ def write_data_to_output(df: DataFrame, config: OutputDataset) -> None:
            .trigger(**config.trigger)
            .start(config.location)
        )
-        query.awaitTermination()
+        return query
+
     else:
         (
             df.write.format(config.format)
             .mode(config.output_mode)
             .options(**config.options)
             .save(config.location)
         )
+
+    return None
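
The behavioral change worth calling out: write_data_to_output no longer calls query.awaitTermination(), so streaming writes return immediately and the caller decides when to wait or stop. A minimal sketch of the resulting calling pattern; the DataFrame and config names here are placeholders:

query = write_data_to_output(streaming_df, config=stream_config)
if query is not None:
    # Either block until the query finishes (optionally with a timeout)...
    # query.awaitTermination(timeout=30)
    # ...or let it run for a while and stop it explicitly, as the updated test does.
    query.stop()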

tests/test_output.py

Lines changed: 25 additions & 17 deletions
@@ -1,5 +1,6 @@
 import os
 import shutil
+import time
 import uuid
 import pytest
 
@@ -27,14 +28,15 @@ def get_output_directories(self):
         shutil.rmtree(base_dir, ignore_errors=True)
         print(f"\n\n*** test dir [{base_dir}] deleted")
 
-    @pytest.mark.parametrize("seed_column_name, table_format, table_location", [
-        ("id", "delta", "/table_folder"),
-        ("_id", "json", "/json_data_folder"),
-        ("id", "csv", "/csv_data_folder"),
-    ])
-    def test_build_output_data_batch(self, get_output_directories, seed_column_name, table_format, table_location):
+    @pytest.mark.parametrize("trigger", [{"availableNow": True}, {"once": True}, {"invalid": "yes"}])
+    def test_initialize_output_dataset_invalid_trigger(self, trigger):
+        with pytest.raises(ValueError, match=f"Attribute 'trigger' must be a dictionary of the form"):
+            _ = dg.OutputDataset(location="/location", trigger=trigger)
+
+    @pytest.mark.parametrize("seed_column_name, table_format", [("id", "parquet"), ("_id", "json"), ("id", "csv")])
+    def test_build_output_data_batch(self, get_output_directories, seed_column_name, table_format):
         base_dir, data_dir, checkpoint_dir = get_output_directories
-        table_dir = f"{data_dir}/{table_location}"
+        table_dir = f"{data_dir}/{uuid.uuid4()}"
 
         gen = dg.DataGenerator(
             sparkSession=spark,
@@ -59,21 +61,17 @@ def test_build_output_data_batch(self, get_output_directories, seed_column_name,
             location=table_dir,
             output_mode="append",
             format=table_format,
-            options={"mergeSchema": "true", "checkpointLocation": f"{data_dir}/{checkpoint_dir}"},
+            options={"mergeSchema": "true"},
         )
 
         gen.buildOutputDataset(output_dataset)
         persisted_df = spark.read.format(table_format).load(table_dir)
         assert persisted_df.count() > 0
 
-    @pytest.mark.parametrize("seed_column_name, table_format, table_location", [
-        ("id", "delta", "/table_folder"),
-        ("_id", "json", "/json_data_folder"),
-        ("id", "csv", "/csv_data_folder"),
-    ])
-    def test_build_output_data_streaming(self, get_output_directories, seed_column_name, table_format, table_location):
+    @pytest.mark.parametrize("seed_column_name, table_format", [("id", "parquet"), ("_id", "json"), ("id", "csv")])
+    def test_build_output_data_streaming(self, get_output_directories, seed_column_name, table_format):
         base_dir, data_dir, checkpoint_dir = get_output_directories
-        table_dir = f"{data_dir}/{table_location}"
+        table_dir = f"{data_dir}/{uuid.uuid4()}"
 
         gen = dg.DataGenerator(
             sparkSession=spark,
@@ -99,9 +97,19 @@ def test_build_output_data_streaming(self, get_output_directories, seed_column_n
             output_mode="append",
             format=table_format,
             options={"mergeSchema": "true", "checkpointLocation": f"{data_dir}/{checkpoint_dir}"},
-            trigger={"availableNow": True}
+            trigger={"processingTime": "1 SECOND"}
         )
 
-        gen.buildOutputDataset(output_dataset)
+        query = gen.buildOutputDataset(output_dataset, with_streaming=True)
+
+        start_time = time.time()
+        elapsed_time = 0
+        time_limit = 10.0
+
+        while elapsed_time < time_limit:
+            time.sleep(1)
+            elapsed_time = time.time() - start_time
+
+        query.stop()
         persisted_df = spark.read.format(table_format).load(table_dir)
         assert persisted_df.count() > 0
