Commit da3cfec

removing ann101/2 as they are redundant, remove some ignores, remove relative imports
1 parent 39e9d44 commit da3cfec

10 files changed (+56 lines, -79 lines)
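
Every dataset provider touched below gets the same three-part cleanup: relative imports become absolute, imports that were deferred into method bodies behind `# ruff: noqa: I001` and `# noqa: PLC0415` suppressions move to module level, and `**options` annotations change from `object` to `dict[str, Any]`. A condensed before/after sketch of the pattern (`SomeProvider` is a hypothetical name, not one of the files below):

    # Before: relative import, plus lazy in-method imports previously used
    # to sidestep circular imports.
    #
    #     from .dataset_provider import DatasetProvider, dataset_definition
    #
    #     class SomeProvider(DatasetProvider):
    #         def getTableGenerator(self, sparkSession, *, tableName=None,
    #                               rows=-1, partitions=-1, **options: object):
    #             import dbldatagen as dg  # noqa: PLC0415

    # After: absolute, module-level imports and dict-typed **options.
    from typing import Any

    import dbldatagen as dg
    from dbldatagen.data_generator import DataGenerator
    from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition

    class SomeProvider(DatasetProvider):
        def getTableGenerator(self, sparkSession, *, tableName=None,
                              rows=-1, partitions=-1, **options: dict[str, Any]):
            ...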

dbldatagen/datasets/basic_geometries.py
Lines changed: 5 additions & 8 deletions

@@ -1,10 +1,11 @@
-from typing import ClassVar
+import warnings as w
+from typing import Any, ClassVar

 from pyspark.sql import SparkSession

+import dbldatagen as dg
 from dbldatagen.data_generator import DataGenerator
-
-from .dataset_provider import DatasetProvider, dataset_definition
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition


 @dataset_definition(name="basic/geometries",
@@ -51,11 +52,7 @@ class BasicGeometriesProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
     ]

     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
-    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
-        # ruff: noqa: I001
-        import dbldatagen as dg  # noqa: PLC0415
-        import warnings as w  # noqa: PLC0415
-
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
         generateRandom = options.get("random", False)
         geometryType = options.get("geometryType", "point")
         maxVertices = options.get("maxVertices", 1 if geometryType == "point" else 3)

dbldatagen/datasets/basic_process_historian.py
Lines changed: 6 additions & 7 deletions

@@ -1,10 +1,11 @@
-from typing import ClassVar
+from typing import Any, ClassVar

+import numpy as np
 from pyspark.sql import SparkSession

+import dbldatagen as dg
 from dbldatagen.data_generator import DataGenerator
-
-from .dataset_provider import DatasetProvider, dataset_definition
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition


 @dataset_definition(name="basic/process_historian",
@@ -60,10 +61,8 @@ class BasicProcessHistorianProvider(DatasetProvider.NoAssociatedDatasetsMixin, D
     ]

     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
-    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
-        # ruff: noqa: I001
-        import dbldatagen as dg  # noqa: PLC0415  # import locally to avoid circular imports
-        import numpy as np  # noqa: PLC0415
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
+

         generateRandom = options.get("random", False)
         numDevices = options.get("numDevices", self.DEFAULT_NUM_DEVICES)

dbldatagen/datasets/basic_stock_ticker.py
Lines changed: 3 additions & 5 deletions

@@ -1,11 +1,11 @@
-from typing import ClassVar
 from random import random
+from typing import ClassVar

 from pyspark.sql import SparkSession

+import dbldatagen as dg
 from dbldatagen.data_generator import DataGenerator
-
-from .dataset_provider import DatasetProvider, dataset_definition
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition


 @dataset_definition(name="basic/stock_ticker",
@@ -43,8 +43,6 @@ class BasicStockTickerProvider(DatasetProvider.NoAssociatedDatasetsMixin, Datase

     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
     def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
-        # ruff: noqa: I001
-        import dbldatagen as dg  # noqa: PLC0415

         numSymbols = options.get("numSymbols", self.DEFAULT_NUM_SYMBOLS)
         startDate = options.get("startDate", self.DEFAULT_START_DATE)

dbldatagen/datasets/basic_telematics.py
Lines changed: 6 additions & 8 deletions

@@ -1,9 +1,11 @@
-from typing import ClassVar
+import warnings as w
+from typing import Any, ClassVar
+
 from pyspark.sql import SparkSession

+import dbldatagen as dg
 from dbldatagen.data_generator import DataGenerator
-
-from .dataset_provider import DatasetProvider, dataset_definition
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition


 @dataset_definition(name="basic/telematics",
@@ -60,11 +62,7 @@ class BasicTelematicsProvider(DatasetProvider.NoAssociatedDatasetsMixin, Dataset
     ]

     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
-    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
-        # ruff: noqa: I001
-        import warnings as w  # noqa: PLC0415
-
-        import dbldatagen as dg  # noqa: PLC0415
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:

         generateRandom = options.get("random", False)
         numDevices = options.get("numDevices", self.DEFAULT_NUM_DEVICES)

dbldatagen/datasets/basic_user.py
Lines changed: 5 additions & 6 deletions

@@ -1,8 +1,10 @@
+from typing import Any
+
 from pyspark.sql import SparkSession

+import dbldatagen as dg
 from dbldatagen.data_generator import DataGenerator
-
-from .dataset_provider import DatasetProvider, dataset_definition
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition


 @dataset_definition(name="basic/user", summary="Basic User Data Set", autoRegister=True, supportsStreaming=True)
@@ -31,10 +33,7 @@ class BasicUserProvider(DatasetProvider.NoAssociatedDatasetsMixin, DatasetProvid
     COLUMN_COUNT = 5

     @DatasetProvider.allowed_options(options=["random", "dummyValues"])
-    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
-        # ruff: noqa: I001
-        import dbldatagen as dg  # noqa: PLC0415
-
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
         generateRandom = options.get("random", False)
         dummyValues = options.get("dummyValues", 0)


dbldatagen/datasets/benchmark_groupby.py
Lines changed: 6 additions & 7 deletions

@@ -1,9 +1,11 @@
-from typing import ClassVar
+import warnings as w
+from typing import Any, ClassVar
+
 from pyspark.sql import SparkSession

+import dbldatagen as dg
 from dbldatagen.data_generator import DataGenerator
-
-from .dataset_provider import DatasetProvider, dataset_definition
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition


 @dataset_definition(name="benchmark/groupby",
@@ -39,10 +41,7 @@ class BenchmarkGroupByProvider(DatasetProvider.NoAssociatedDatasetsMixin, Datase
     ALLOWED_OPTIONS: ClassVar[list[str]] = ["groups", "percentNulls", "rows", "partitions", "tableName", "random"]

     @DatasetProvider.allowed_options(options=ALLOWED_OPTIONS)
-    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
-        # ruff: noqa: I001
-        import dbldatagen as dg  # noqa: PLC0415
-        import warnings as w  # noqa: PLC0415
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:

         generateRandom = options.get("random", False)
         groups = options.get("groups", self.DEFAULT_NUM_GROUPS)

dbldatagen/datasets/dataset_provider.py
Lines changed: 5 additions & 5 deletions

@@ -187,7 +187,7 @@ def getRegisteredDatasetsVersion(cls) -> int :
         return cls._registeredDatasetsVersion

     @abstractmethod
-    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
         """Gets data generation instance that will produce table for named table

         :param sparkSession: Spark session to use
@@ -207,7 +207,7 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N

     @abstractmethod
     def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1,
-                             **options: object) -> DataGenerator:
+                             **options: dict[str, Any]) -> DataGenerator:
         """
         Gets associated datasets that are used in conjunction with the provider datasets.
         These may be associated lookup tables, tables that execute benchmarks or exercise key features as part of
@@ -288,7 +288,7 @@ class NoAssociatedDatasetsMixin(ABC):  # noqa: B024
         any associated datasets
         """
         def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int =-1,
-                                 **options: object) -> DataGenerator:
+                                 **options: dict[str, Any]) -> DataGenerator:
             raise NotImplementedError("Data provider does not produce any associated datasets!")

 class DatasetDecoratorUtils:
@@ -381,7 +381,7 @@ def mkClass(self, autoRegister: bool =False) -> type:
         return retval


-def dataset_definition(cls: type|None =None, *args: Any, autoRegister: bool =False, **kwargs: Any) -> type:  # pylint: disable=keyword-arg-before-vararg # noqa: ANN401
+def dataset_definition(cls: type|None =None, *args: object, autoRegister: bool =False, **kwargs: object) -> type:
     """ decorator to define standard dataset definition

     This is intended to be applied classes derived from DatasetProvider to simplify the implementation
@@ -414,7 +414,7 @@ class X(DatasetProvider)

     """

-    def inner_wrapper(inner_cls: type|None =None, *inner_args: Any, **inner_kwargs) -> type:  # pylint: disable=keyword-arg-before-vararg # noqa: ANN401
+    def inner_wrapper(inner_cls: type|None =None, *inner_args: object, **inner_kwargs: object) -> type:
         """ The inner wrapper function is used to handle the case where the decorator is used with arguments.
         It defers the application of the decorator to the target class until the target class is available.

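The `Any` → `object` swap in `dataset_definition` and `inner_wrapper` is what makes the dropped `# noqa: ANN401` suppressions unnecessary: ANN401 flags `Any` because it disables type checking for that value entirely, while `object` still accepts any argument but forces the checker to demand narrowing before value-specific operations. A minimal illustration (hypothetical `describe_*` functions, not from this codebase):

    from typing import Any

    def describe_any(*args: Any) -> None:
        for a in args:
            print(a.upper())  # checker stays silent, even if `a` is an int

    def describe_object(*args: object) -> None:
        for a in args:
            if isinstance(a, str):  # narrowing required before str-specific calls
                print(a.upper())
            else:
                print(repr(a))
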
dbldatagen/datasets/multi_table_sales_order_provider.py
Lines changed: 13 additions & 21 deletions

@@ -1,8 +1,11 @@
-from pyspark.sql import SparkSession
+from typing import Any

-from dbldatagen.data_generator import DataGenerator
+import pyspark.sql.functions as F
+from pyspark.sql import DataFrame, SparkSession

-from .dataset_provider import DatasetProvider, dataset_definition
+import dbldatagen as dg
+from dbldatagen.data_generator import DataGenerator
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition


 @dataset_definition(name="multi_table/sales_order", summary="Multi-table sales order dataset", supportsStreaming=True,
@@ -65,7 +68,7 @@ class MultiTableSalesOrderProvider(DatasetProvider):
     INVOICE_MIN_VALUE = 1_000_000

     def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCustomers: int, dummyValues: int) -> DataGenerator:
-        import dbldatagen as dg # noqa: PLC0415
+        import dbldatagen as dg  # noqa: PLC0415

         # Validate the options:
         if numCustomers is None or numCustomers < 0:
@@ -106,7 +109,7 @@ def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int
         return customers_data_spec

     def getCarriers(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCarriers: int, dummyValues: int) -> DataGenerator:
-        import dbldatagen as dg # noqa: PLC0415
+        import dbldatagen as dg  # noqa: PLC0415

         # Validate the options:
         if numCarriers is None or numCarriers < 0:
@@ -143,7 +146,7 @@ def getCarriers(self, sparkSession: SparkSession, *, rows: int, partitions: int,
         return carriers_data_spec

     def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCatalogItems: int, dummyValues: int) -> DataGenerator:
-        import dbldatagen as dg # noqa: PLC0415
+        import dbldatagen as dg  # noqa: PLC0415

         # Validate the options:
         if numCatalogItems is None or numCatalogItems < 0:
@@ -184,7 +187,7 @@ def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions:

     def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCustomers: int, startDate: str,
                       endDate: str, dummyValues: int) -> DataGenerator:
-        import dbldatagen as dg # noqa: PLC0415
+        import dbldatagen as dg  # noqa: PLC0415

         # Validate the options:
         if numOrders is None or numOrders < 0:
@@ -231,7 +234,7 @@ def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: in

     def getBaseOrderLineItems(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCatalogItems: int,
                               lineItemsPerOrder: int, dummyValues: int) -> DataGenerator:
-        import dbldatagen as dg # noqa: PLC0415
+        import dbldatagen as dg  # noqa: PLC0415

         # Validate the options:
         if numOrders is None or numOrders < 0:
@@ -268,9 +271,6 @@ def getBaseOrderLineItems(self, sparkSession: SparkSession, *, rows: int, partit
         return base_order_line_items_data_spec

     def getBaseOrderShipments(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCarriers: int, dummyValues: int) -> DataGenerator:
-        # ruff: noqa: I001
-        import dbldatagen as dg  # noqa: PLC0415
-
         # Validate the options:
         if numOrders is None or numOrders < 0:
             numOrders = self.DEFAULT_NUM_ORDERS
@@ -311,10 +311,6 @@ def getBaseOrderShipments(self, sparkSession: SparkSession, *, rows: int, partit
         return base_order_shipments_data_spec

     def getBaseInvoices(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, dummyValues: int) -> DataGenerator:
-        # ruff: noqa: I001
-        import dbldatagen as dg  # noqa: PLC0415
-
-
         # Validate the options:
         if numOrders is None or numOrders < 0:
             numOrders = self.DEFAULT_NUM_ORDERS
@@ -354,7 +350,7 @@ def getBaseInvoices(self, sparkSession: SparkSession, *, rows: int, partitions:

     @DatasetProvider.allowed_options(options=["numCustomers", "numCarriers", "numCatalogItems", "numOrders",
                                               "lineItemsPerOrder", "startDate", "endDate", "dummyValues"])
-    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
         # Get the option values:
         numCustomers = options.get("numCustomers", self.DEFAULT_NUM_CUSTOMERS)
         numCarriers = options.get("numCarriers", self.DEFAULT_NUM_CARRIERS)
@@ -443,11 +439,7 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N
                                               "baseOrderShipments",
                                               "baseInvoices"
                                               ])
-    def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
-        # ruff: noqa: I001
-        from pyspark.sql import DataFrame  # noqa: PLC0415
-        import pyspark.sql.functions as F  # noqa: PLC0415
-
+    def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
         dfCustomers = options.get("customers")
         assert dfCustomers is not None and issubclass(type(dfCustomers), DataFrame), \
             "Option `customers` should be a dataframe of customer records"

dbldatagen/datasets/multi_table_telephony_provider.py
Lines changed: 6 additions & 9 deletions

@@ -1,8 +1,10 @@
+from typing import Any
 from pyspark.sql import SparkSession

 from dbldatagen.data_generator import DataGenerator

-from .dataset_provider import DatasetProvider, dataset_definition
+from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition
+import dbldatagen as dg


 @dataset_definition(name="multi_table/telephony", summary="Multi-table telephony dataset", supportsStreaming=True,
@@ -60,7 +62,6 @@ class MultiTableTelephonyProvider(DatasetProvider):
     DEFAULT_AVG_EVENTS_PER_CUSTOMER = 50

     def getPlans(self, sparkSession: SparkSession, *, rows: int, partitions: int, generateRandom: bool, numPlans: int, dummyValues: int) -> DataGenerator:
-        import dbldatagen as dg  # noqa: PLC0415

         if numPlans is None or numPlans < 0:
             numPlans = self.DEFAULT_NUM_PLANS
@@ -219,7 +220,7 @@ def getDeviceEvents(self, sparkSession: SparkSession, *, rows: int, partitions:

     @DatasetProvider.allowed_options(options=["random", "numPlans", "numCustomers", "dummyValues", "numDays",
                                               "averageEventsPerCustomer"])
-    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: object) -> DataGenerator:
+    def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1, **options: dict[str, Any]) -> DataGenerator:
         generateRandom = options.get("random", False)
         numPlans = options.get("numPlans", self.DEFAULT_NUM_PLANS)
         numCustomers = options.get("numCustomers", self.DEFAULT_NUM_CUSTOMERS)
@@ -240,7 +241,7 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N

     @DatasetProvider.allowed_options(options=["plans", "customers", "deviceEvents"])
     def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1,
-                             **options: object) -> DataGenerator:
+                             **options: dict[str, Any]) -> DataGenerator:
         # ruff: noqa: I001
         import pyspark.sql.functions as F  #noqa: PLC0415
         from pyspark.sql import DataFrame  #noqa: PLC0415
@@ -297,11 +298,7 @@ def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|Non
         """)

         df_summary.createOrReplaceTempView("mtp_event_summary")
-
-        df_customer_summary = (  # noqa: F841
-            df_customer_pricing.join(df_summary,
-                                     df_customer_pricing.device_id == df_summary.device_id)
-            .createOrReplaceTempView("mtp_customer_summary"))
+        df_customer_pricing.join(df_summary,df_customer_pricing.device_id == df_summary.device_id).createOrReplaceTempView("mtp_customer_summary")

         df_invoices = sparkSession.sql("""
             select *,

pyproject.toml
Lines changed: 1 addition & 3 deletions

@@ -65,7 +65,7 @@ packages = ["dbldatagen"]
 [tool.hatch.build.targets.sdist]
 include = [
     "/dbldatagen",
-    "/tests",
+    "/tests",
     "/examples",
     "/tutorial",
     "/docs",
@@ -185,8 +185,6 @@ ignore = [
     "SIM102",  # Use a single if-statement
     "SIM108",  # Use ternary operator
     "UP007",   # Use X | Y for type annotations (keep Union for compatibility)
-    "ANN101",  # Missing type annotation for `self` in method
-    "ANN102",  # Missing type annotation for `cls` in method
     "ANN003",  # Missing type annotation for **kwargs
 ]

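For context on the dropped ignore entries: ANN101 and ANN102 (missing annotations for `self` and `cls`) were deprecated upstream in Ruff because type checkers infer both implicitly, so keeping them in the `ignore` list had become redundant, as the commit message notes. A minimal sketch of why the annotations add nothing (hypothetical `Example` class):

    class Example:
        def method(self) -> None:  # `self: Example` is inferred automatically
            ...

        @classmethod
        def build(cls) -> "Example":  # `cls: type[Example]` is inferred as well
            return cls()
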