|
1 | | -from pyspark.sql import SparkSession |
| 1 | +from typing import Any |
2 | 2 |
|
3 | | -from dbldatagen.data_generator import DataGenerator |
| 3 | +import pyspark.sql.functions as F |
| 4 | +from pyspark.sql import DataFrame, SparkSession |
4 | 5 |
|
5 | | -from .dataset_provider import DatasetProvider, dataset_definition |
| 6 | +import dbldatagen as dg |
| 7 | +from dbldatagen.data_generator import DataGenerator |
| 8 | +from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition |
6 | 9 |
|
7 | 10 |
|
8 | 11 | @dataset_definition(name="multi_table/sales_order", summary="Multi-table sales order dataset", supportsStreaming=True, |
@@ -65,7 +68,7 @@ class MultiTableSalesOrderProvider(DatasetProvider): |
65 | 68 | INVOICE_MIN_VALUE = 1_000_000 |
66 | 69 |
|
67 | 70 | def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCustomers: int, dummyValues: int) -> DataGenerator: |
68 | | - import dbldatagen as dg # noqa: PLC0415 |
| 71 | + import dbldatagen as dg # noqa: PLC0415 |
69 | 72 |
|
70 | 73 | # Validate the options: |
71 | 74 | if numCustomers is None or numCustomers < 0: |
@@ -106,7 +109,7 @@ def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int |
106 | 109 | return customers_data_spec |
107 | 110 |
|
108 | 111 | def getCarriers(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCarriers: int, dummyValues: int) -> DataGenerator: |
109 | | - import dbldatagen as dg # noqa: PLC0415 |
| 112 | + import dbldatagen as dg # noqa: PLC0415 |
110 | 113 |
|
111 | 114 | # Validate the options: |
112 | 115 | if numCarriers is None or numCarriers < 0: |
@@ -143,7 +146,7 @@ def getCarriers(self, sparkSession: SparkSession, *, rows: int, partitions: int, |
143 | 146 | return carriers_data_spec |
144 | 147 |
|
145 | 148 | def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCatalogItems: int, dummyValues: int) -> DataGenerator: |
146 | | - import dbldatagen as dg # noqa: PLC0415 |
| 149 | + import dbldatagen as dg # noqa: PLC0415 |
147 | 150 |
|
148 | 151 | # Validate the options: |
149 | 152 | if numCatalogItems is None or numCatalogItems < 0: |
@@ -184,7 +187,7 @@ def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions: |
184 | 187 |
|
185 | 188 | def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCustomers: int, startDate: str, |
186 | 189 | endDate: str, dummyValues: int) -> DataGenerator: |
187 | | - import dbldatagen as dg # noqa: PLC0415 |
| 190 | + import dbldatagen as dg # noqa: PLC0415 |
188 | 191 |
|
189 | 192 | # Validate the options: |
190 | 193 | if numOrders is None or numOrders < 0: |
@@ -231,7 +234,7 @@ def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: in |
231 | 234 |
|
232 | 235 | def getBaseOrderLineItems(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCatalogItems: int, |
233 | 236 | lineItemsPerOrder: int, dummyValues: int) -> DataGenerator: |
234 | | - import dbldatagen as dg # noqa: PLC0415 |
| 237 | + import dbldatagen as dg # noqa: PLC0415 |
235 | 238 |
|
236 | 239 | # Validate the options: |
237 | 240 | if numOrders is None or numOrders < 0: |
@@ -268,9 +271,6 @@ def getBaseOrderLineItems(self, sparkSession: SparkSession, *, rows: int, partit |
268 | 271 | return base_order_line_items_data_spec |
269 | 272 |
|
270 | 273 | def getBaseOrderShipments(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCarriers: int, dummyValues: int) -> DataGenerator: |
271 | | - # ruff: noqa: I001 |
272 | | - import dbldatagen as dg # noqa: PLC0415 |
273 | | - |
274 | 274 | # Validate the options: |
275 | 275 | if numOrders is None or numOrders < 0: |
276 | 276 | numOrders = self.DEFAULT_NUM_ORDERS |
@@ -311,10 +311,6 @@ def getBaseOrderShipments(self, sparkSession: SparkSession, *, rows: int, partit |
311 | 311 | return base_order_shipments_data_spec |
312 | 312 |
|
313 | 313 | def getBaseInvoices(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, dummyValues: int) -> DataGenerator: |
314 | | - # ruff: noqa: I001 |
315 | | - import dbldatagen as dg # noqa: PLC0415 |
316 | | - |
317 | | - |
318 | 314 | # Validate the options: |
319 | 315 | if numOrders is None or numOrders < 0: |
320 | 316 | numOrders = self.DEFAULT_NUM_ORDERS |
@@ -354,7 +350,7 @@ def getBaseInvoices(self, sparkSession: SparkSession, *, rows: int, partitions: |
354 | 350 |
|
355 | 351 | @DatasetProvider.allowed_options(options=["numCustomers", "numCarriers", "numCatalogItems", "numOrders", |
356 | 352 | "lineItemsPerOrder", "startDate", "endDate", "dummyValues"]) |
357 | | - def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator: |
| 353 | + def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: |
358 | 354 | # Get the option values: |
359 | 355 | numCustomers = options.get("numCustomers", self.DEFAULT_NUM_CUSTOMERS) |
360 | 356 | numCarriers = options.get("numCarriers", self.DEFAULT_NUM_CARRIERS) |
@@ -443,11 +439,7 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N |
443 | 439 | "baseOrderShipments", |
444 | 440 | "baseInvoices" |
445 | 441 | ]) |
446 | | - def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator: |
447 | | - # ruff: noqa: I001 |
448 | | - from pyspark.sql import DataFrame # noqa: PLC0415 |
449 | | - import pyspark.sql.functions as F # noqa: PLC0415 |
450 | | - |
| 442 | + def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator: |
451 | 443 | dfCustomers = options.get("customers") |
452 | 444 | assert dfCustomers is not None and issubclass(type(dfCustomers), DataFrame), \ |
453 | 445 | "Option `customers` should be a dataframe of customer records" |
|
0 commit comments