Skip to content

Commit 5688fa8

Browse files
committed
Correcting return type
Removing noqa: PLC0415
1 parent 9ce27c9 commit 5688fa8

File tree

3 files changed

+6
-24
lines changed

3 files changed

+6
-24
lines changed

dbldatagen/datasets/dataset_provider.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def wrapper(*args, **kwargs) -> Callable: # noqa: ANN002
254254

255255
return decorator
256256

257-
def checkOptions(self, options: dict[str, Any], allowedOptions: list[str]) -> None:
257+
def checkOptions(self, options: dict[str, Any], allowedOptions: list[str]) -> DatasetDefinition:
258258
""" Check that options are valid
259259
260260
:param options: options to check as dict

dbldatagen/datasets/multi_table_sales_order_provider.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ class MultiTableSalesOrderProvider(DatasetProvider):
6868
INVOICE_MIN_VALUE = 1_000_000
6969

7070
def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCustomers: int, dummyValues: int) -> DataGenerator:
71-
import dbldatagen as dg # noqa: PLC0415
72-
7371
# Validate the options:
7472
if numCustomers is None or numCustomers < 0:
7573
numCustomers = self.DEFAULT_NUM_CUSTOMERS
@@ -109,8 +107,6 @@ def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int
109107
return customers_data_spec
110108

111109
def getCarriers(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCarriers: int, dummyValues: int) -> DataGenerator:
112-
import dbldatagen as dg # noqa: PLC0415
113-
114110
# Validate the options:
115111
if numCarriers is None or numCarriers < 0:
116112
numCarriers = self.DEFAULT_NUM_CARRIERS
@@ -146,9 +142,6 @@ def getCarriers(self, sparkSession: SparkSession, *, rows: int, partitions: int,
146142
return carriers_data_spec
147143

148144
def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions: int, numCatalogItems: int, dummyValues: int) -> DataGenerator:
149-
import dbldatagen as dg # noqa: PLC0415
150-
151-
# Validate the options:
152145
if numCatalogItems is None or numCatalogItems < 0:
153146
numCatalogItems = self.DEFAULT_NUM_CATALOG_ITEMS
154147
if rows is None or rows < 0:
@@ -187,8 +180,6 @@ def getCatalogItems(self, sparkSession: SparkSession, *, rows: int, partitions:
187180

188181
def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCustomers: int, startDate: str,
189182
endDate: str, dummyValues: int) -> DataGenerator:
190-
import dbldatagen as dg # noqa: PLC0415
191-
192183
# Validate the options:
193184
if numOrders is None or numOrders < 0:
194185
numOrders = self.DEFAULT_NUM_ORDERS
@@ -234,9 +225,6 @@ def getBaseOrders(self, sparkSession: SparkSession, *, rows: int, partitions: in
234225

235226
def getBaseOrderLineItems(self, sparkSession: SparkSession, *, rows: int, partitions: int, numOrders: int, numCatalogItems: int,
236227
lineItemsPerOrder: int, dummyValues: int) -> DataGenerator:
237-
import dbldatagen as dg # noqa: PLC0415
238-
239-
# Validate the options:
240228
if numOrders is None or numOrders < 0:
241229
numOrders = self.DEFAULT_NUM_ORDERS
242230
if numCatalogItems is None or numCatalogItems < 0:

dbldatagen/datasets/multi_table_telephony_provider.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from typing import Any
2-
from pyspark.sql import SparkSession
32

4-
from dbldatagen.data_generator import DataGenerator
3+
import pyspark.sql.functions as F
4+
from pyspark.sql import DataFrame, SparkSession
55

6-
from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition
76
import dbldatagen as dg
7+
from dbldatagen.data_generator import DataGenerator
8+
from dbldatagen.datasets.dataset_provider import DatasetProvider, dataset_definition
89

910

1011
@dataset_definition(name="multi_table/telephony", summary="Multi-table telephony dataset", supportsStreaming=True,
@@ -107,8 +108,6 @@ def getPlans(self, sparkSession: SparkSession, *, rows: int, partitions: int, ge
107108
return plan_dataspec
108109

109110
def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int, generateRandom: bool, numCustomers: int, numPlans: int, dummyValues: int) -> DataGenerator:
110-
import dbldatagen as dg # noqa: PLC0415
111-
112111
if numCustomers is None or numCustomers < 0:
113112
numCustomers = self.DEFAULT_NUM_CUSTOMERS
114113

@@ -149,7 +148,6 @@ def getCustomers(self, sparkSession: SparkSession, *, rows: int, partitions: int
149148

150149
def getDeviceEvents(self, sparkSession: SparkSession, *, rows: int, partitions: int, generateRandom: bool, numCustomers: int, numDays: int, dummyValues: int,
151150
averageEventsPerCustomer: int) -> DataGenerator:
152-
import dbldatagen as dg # noqa: PLC0415
153151
MB_100 = 100 * 1000 * 1000
154152
K_1 = 1000
155153

@@ -242,13 +240,9 @@ def getTableGenerator(self, sparkSession: SparkSession, *, tableName: str|None=N
242240
@DatasetProvider.allowed_options(options=["plans", "customers", "deviceEvents"])
243241
def getAssociatedDataset(self, sparkSession: SparkSession, *, tableName: str|None=None, rows: int=-1, partitions: int=-1,
244242
**options: dict[str, Any]) -> DataGenerator:
245-
# ruff: noqa: I001
246-
import pyspark.sql.functions as F #noqa: PLC0415
247-
from pyspark.sql import DataFrame #noqa: PLC0415
248243

249244
dfPlans = options.get("plans")
250-
assert dfPlans is not None and issubclass(type(dfPlans), DataFrame), \
251-
"Option `plans` should be a dataframe of plan records"
245+
assert dfPlans is not None and issubclass(type(dfPlans), DataFrame), "Option `plans` should be a dataframe of plan records"
252246

253247
dfCustomers = options.get("customers")
254248
assert dfCustomers is not None and issubclass(type(dfCustomers), DataFrame), \

0 commit comments

Comments (0)