
Commit 67786c9

Update dependencies, tools, and tests

Parent: 9095018

10 files changed: 80 additions, 61 deletions


.github/workflows/push.yml

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ jobs:
        run: pip install hatch

      - name: Run unit tests
-       run: make test
+       run: make dev test

      - name: Publish test coverage to coverage site
        uses: codecov/codecov-action@v4

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -7,13 +7,22 @@ All notable changes to the Databricks Labs Data Generator will be documented in

#### Fixed
* Updated build scripts to use Ubuntu 22.04 to correspond to environment in Databricks runtime
+* Refactored `DataAnalyzer` and `BasicStockTickerProvider` to comply with ANSI SQL standards
+* Removed internal modification of `SparkSession`

#### Changed
* Changed base Databricks runtime version to DBR 13.3 LTS (based on Apache Spark 3.4.1) - minimum supported version
  of Python is now 3.10.12
+* Updated build tooling to use [hatch](https://hatch.pypa.io/latest/)
+* Moved dependencies and tool configuration to [pyproject.toml](pyproject.toml)
+* Removed dependencies provided by the Databricks Runtime
+* Updated Git actions
+* Updated [makefile](makefile)
+* Updated [CONTRIBUTING.md](CONTRIBUTING.md)

#### Added
* Added support for serialization to/from JSON format
+* Added Ruff and mypy tooling


### Version 0.4.0 Hotfix 2

README.md

Lines changed: 1 addition & 2 deletions
@@ -168,8 +168,7 @@ runtimes.
By design, installing `dbldatagen` does not install releases of dependent packages in order
to preserve the curated set of packages pre-installed in any Databricks runtime environment.

-When building on local environments, the build process uses the `Pipfile` and requirements files to determine
-the package versions for releases and unit tests.
+When building on local environments, run `make dev` to install required dependencies.

## Project Support
Please note that all projects released under [`Databricks Labs`](https://www.databricks.com/learn/labs)

dbldatagen/column_generation_spec.py

Lines changed: 1 addition & 1 deletion
@@ -943,7 +943,7 @@ def _getSeedExpression(self, base_column):
            else:
                return col(base_column[0])
        elif self._baseColumnComputeMethod == VALUES_COMPUTE_METHOD:
-           base_values = [f"string(ifnull(`{x}`, 'null'))" for x in base_column]
+           base_values = [f"string(ifnull(`{x}`, cast(null as string)))" for x in base_column]
            return expr(f"array({','.join(base_values)})")
        else:
            return expr(f"hash({','.join(base_column)})")
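Note: the seed expression now substitutes a typed SQL NULL (`cast(null as string)`) instead of the literal string `'null'`. Below is a minimal standalone sketch, not part of the repository, that evaluates the same f-string pattern against a toy DataFrame; the SparkSession settings and column name are illustrative assumptions.

from pyspark.sql import SparkSession
from pyspark.sql.functions import expr

# Illustrative local session; ANSI mode mirrors the changelog's ANSI SQL theme.
spark = (SparkSession.builder.master("local[1]")
         .config("spark.sql.ansi.enabled", "true")
         .getOrCreate())

df = spark.createDataFrame([("a",), (None,)], ["code"])

# Same expression pattern as _getSeedExpression uses for VALUES_COMPUTE_METHOD
base_values = [f"string(ifnull(`{x}`, cast(null as string)))" for x in ["code"]]
df.select(expr(f"array({','.join(base_values)})").alias("seed_values")).show()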

dbldatagen/data_analyzer.py

Lines changed: 2 additions & 2 deletions
@@ -226,13 +226,13 @@ def summarizeToDF(self):
        # string characteristics for strings and string representation of other values
        dfDataSummary = self._addMeasureToSummary(
            'print_len_min',
-           fieldExprs=[f"min(length(string({dtype[0]}))) as {dtype[0]}" for dtype in dtypes],
+           fieldExprs=[f"string(min(length(string({dtype[0]})))) as {dtype[0]}" for dtype in dtypes],
            dfData=self._df,
            dfSummary=dfDataSummary)

        dfDataSummary = self._addMeasureToSummary(
            'print_len_max',
-           fieldExprs=[f"max(length(string({dtype[0]}))) as {dtype[0]}" for dtype in dtypes],
+           fieldExprs=[f"string(max(length(string({dtype[0]})))) as {dtype[0]}" for dtype in dtypes],
            dfData=self._df,
            dfSummary=dfDataSummary)
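Note: wrapping the aggregates in `string(...)` keeps every `print_len_*` measure string-typed, matching the other string-valued rows the summary combines; the previous integer-typed expressions are presumably what tripped ANSI type checking. A rough standalone illustration (assumes an existing SparkSession named `spark`; the DataFrame is made up):

from pyspark.sql.functions import expr

df = spark.createDataFrame([("alpha",), ("bc",)], ["name"])

# min(length(...)) alone yields an integer column; the outer string() cast
# makes the measure row string-typed like the rest of the summary.
summary_row = df.select(expr("string(min(length(string(name)))) as name"))
summary_row.printSchema()   # name: string
summary_row.show()          # value is "2", the shortest printed length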

dbldatagen/datasets/basic_stock_ticker.py

Lines changed: 7 additions & 7 deletions
@@ -60,14 +60,14 @@ def getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
                        baseColumn="symbol_id", omit=True)
            .withColumn("symbol", "string",
                        expr="""concat_ws('', transform(split(conv(symbol_id, 10, 26), ''),
-                           x -> case when x < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""")
-           .withColumn("days_from_start_date", "int", expr=f"floor(id / {numSymbols})", omit=True)
+                           x -> case when ascii(x) < 10 then char(ascii(x) - 48 + 65) else char(ascii(x) + 10) end))""")
+           .withColumn("days_from_start_date", "int", expr=f"floor(try_divide(id, {numSymbols}))", omit=True)
            .withColumn("post_date", "date", expr=f"date_add(cast('{startDate}' as date), days_from_start_date)")
            .withColumn("start_value", "decimal(11,2)",
-                       values=[1.0 + 199.0 * random() for _ in range(int(numSymbols / 10))], omit=True)
-           .withColumn("growth_rate", "float", values=[-0.1 + 0.35 * random() for _ in range(int(numSymbols / 10))],
+                       values=[1.0 + 199.0 * random() for _ in range(max(1, int(numSymbols / 10)))], omit=True)
+           .withColumn("growth_rate", "float", values=[-0.1 + 0.35 * random() for _ in range(max(1, int(numSymbols / 10)))],
                        baseColumn="symbol_id")
-           .withColumn("volatility", "float", values=[0.0075 * random() for _ in range(int(numSymbols / 10))],
+           .withColumn("volatility", "float", values=[0.0075 * random() for _ in range(max(1, int(numSymbols / 10)))],
                        baseColumn="symbol_id", omit=True)
            .withColumn("prev_modifier_sign", "float",
                        expr=f"case when sin((id - {numSymbols}) % 17) > 0 then -1.0 else 1.0 end""",

@@ -78,12 +78,12 @@ getTableGenerator(self, sparkSession, *, tableName=None, rows=-1, partitions
            .withColumn("open_base", "decimal(11,2)",
                        expr=f"""start_value
                            + (volatility * prev_modifier_sign * start_value * sin((id - {numSymbols}) % 17))
-                           + (growth_rate * start_value * (days_from_start_date - 1) / 365)""",
+                           + (growth_rate * start_value * try_divide(days_from_start_date - 1, 365))""",
                        omit=True)
            .withColumn("close_base", "decimal(11,2)",
                        expr="""start_value
                            + (volatility * start_value * sin(id % 17))
-                           + (growth_rate * start_value * days_from_start_date / 365)""",
+                           + (growth_rate * start_value * try_divide(days_from_start_date, 365))""",
                        omit=True)
            .withColumn("high_base", "decimal(11,2)",
                        expr="greatest(open_base, close_base) + rand() * volatility * open_base",

dbldatagen/text_generators.py

Lines changed: 5 additions & 5 deletions
@@ -856,8 +856,8 @@ def generateText(self, baseValues, rowCount=1):
        # hardening a mask prevents masked values from being changed
        np.ma.harden_mask(masked_offsets)
        # Cast offsets to the same dtype as the array to avoid casting errors
-       capitals_offset = word_offset_type.type(self._startOfCapitalsOffset)
-       spaced_words_offset = word_offset_type.type(self._startOfSpacedWordsOffset)
+       capitals_offset = self._wordOffsetType.type(self._startOfCapitalsOffset)
+       spaced_words_offset = self._wordOffsetType.type(self._startOfSpacedWordsOffset)
        masked_offsets[:, :, :, 0] = masked_offsets[:, :, :, 0] + capitals_offset
        masked_offsets[:, :, :, 1:] = masked_offsets[:, :, :, 1:] + spaced_words_offset
        np.ma.soften_mask(masked_offsets)

@@ -869,7 +869,7 @@ def generateText(self, baseValues, rowCount=1):
        new_col = new_word_offsets[:, :, :, np.newaxis]
        terminated_word_offsets = np.ma.concatenate((masked_offsets, new_col), axis=3)
        new_column = terminated_word_offsets[:, :, :, -1]
-       sentence_end_offset = word_offset_type.type(self._sentenceEndOffset)
+       sentence_end_offset = self._wordOffsetType.type(self._sentenceEndOffset)
        new_column[~new_column.mask] = sentence_end_offset

        # reshape to paragraphs

@@ -887,7 +887,7 @@ def generateText(self, baseValues, rowCount=1):
            # set the paragraph end marker on all paragraphs except last
            # new_masked_elements = terminated_paragraph_offsets[:,:,-1]
            new_column = terminated_paragraph_offsets[:, :, -1]
-           paragraph_end_offset = word_offset_type.type(self._paragraphEnd)
+           paragraph_end_offset = self._wordOffsetType.type(self._paragraphEnd)
            new_column[~new_column.mask] = paragraph_end_offset
        else:
            terminated_paragraph_offsets = paragraph_offsets

@@ -897,7 +897,7 @@ def generateText(self, baseValues, rowCount=1):
        shape = terminated_paragraph_offsets.shape
        terminated_paragraph_offsets = terminated_paragraph_offsets.reshape((rowCount, shape[1] * shape[2]))

-       empty_string_offset = word_offset_type.type(self._emptyStringOffset)
+       empty_string_offset = self._wordOffsetType.type(self._emptyStringOffset)
        final_data = terminated_paragraph_offsets.filled(fill_value=empty_string_offset)

        # its faster to manipulate text in data frames as numpy strings are fixed length
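Note: this change replaces references to an undefined local name `word_offset_type` with the instance attribute `self._wordOffsetType`; each offset is still cast to the array's dtype before the addition. A toy illustration of that casting pattern outside the class (the dtype and offset value are assumptions for the example):

import numpy as np

word_offset_type = np.dtype(np.uint16)   # stand-in for self._wordOffsetType
masked_offsets = np.ma.masked_array(np.zeros((2, 3), dtype=word_offset_type))

# dtype.type(value) yields a NumPy scalar of matching width, so the addition
# below stays within the array's dtype instead of promoting to a wider type.
sentence_end_offset = word_offset_type.type(300)
masked_offsets[:, -1] = masked_offsets[:, -1] + sentence_end_offset
assert masked_offsets.dtype == np.uint16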

pyproject.toml

Lines changed: 12 additions & 19 deletions
@@ -29,16 +29,7 @@ classifiers = [
]
dependencies = [
    "databricks-sdk~=0.57",
-   "numpy>=1.22.0",
-   "pandas>=1.3.4",
-   "pyarrow>=7.0.0",
-   "pyspark[sql]>=3.3.0",
-   "python-dateutil>=2.8.2",
-   "six>=1.16.0",
-   "pyparsing>=3.0.4",
-   "jmespath>=0.10.0",
-   "py4j>=0.10.9",
-   "pickleshare>=0.7.5",
+   "py4j>=0.10.9"
]

[project.urls]

@@ -49,14 +40,6 @@ Homepage = "https://github.com/databrickslabs/dbldatagen"
Repository = "https://github.com/databrickslabs/dbldatagen.git"

[project.optional-dependencies]
-dev = [
-   "pytest>=6.0.0",
-   "pytest-cov>=3.0.0",
-   "pytest-timeout",
-   "ruff>=0.1.0",
-   "pylint>=2.15.0",
-   "mypy>=1.0.0",
-]
docs = [
    "sphinx>=7.0.0",
    "sphinx-rtd-theme",

@@ -110,7 +93,17 @@ dependencies = [
    "ruff~=0.3.4",
    "types-PyYAML~=6.0.12",
    "types-requests~=2.31.0",
-   "pyspark[sql]~=3.5.0"
+   "databricks-sdk~=0.57",
+   "numpy>=1.21.5",
+   "pandas>=1.4.4",
+   "pyarrow>=8.0.0",
+   "pyspark[sql]>=3.4.1",
+   "python-dateutil>=2.8.2",
+   "six>=1.16.0",
+   "pyparsing>=3.0.9",
+   "jmespath>=0.10.0",
+   "py4j>=0.10.9",
+   "pickleshare>=0.7.5",
]

python="3.10"

tests/test_complex_columns.py

Lines changed: 18 additions & 18 deletions
@@ -523,11 +523,11 @@ def test_inferred_column_structs1(self, setupLogging):
        df = df_spec.build()

        type1 = self.getFieldType(df.schema, "struct1")
-       expectedType = StructType([StructField('a', IntegerType()), StructField('b', IntegerType())])
+       expectedType = StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)])
        assert type1 == expectedType

        type2 = self.getFieldType(df.schema, "struct2")
-       expectedType2 = StructType([StructField('a', DateType(), False), StructField('b', StringType())])
+       expectedType2 = StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)])
        assert type2 == expectedType2

    def test_inferred_column_structs2(self, setupLogging):

@@ -551,13 +551,13 @@ def test_inferred_column_structs2(self, setupLogging):
        df = df_spec.build()

        type1 = self.getFieldType(df.schema, "struct1")
-       assert type1 == StructType([StructField('a', IntegerType()), StructField('b', IntegerType())])
+       assert type1 == StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)])
        type2 = self.getFieldType(df.schema, "struct2")
-       assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType())])
+       assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)])
        type3 = self.getFieldType(df.schema, "struct3")
        assert type3 == StructType(
-           [StructField('a', StructType([StructField('a', IntegerType()), StructField('b', IntegerType())]), False),
-            StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType())]), False)]
+           [StructField('a', StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)]), False),
+            StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]), False)]
        )

    def test_with_struct_column1(self, setupLogging):

@@ -580,9 +580,9 @@ def test_with_struct_column1(self, setupLogging):
        df = df_spec.build()

        type1 = self.getFieldType(df.schema, "struct1")
-       assert type1 == StructType([StructField('a', IntegerType()), StructField('b', IntegerType())])
+       assert type1 == StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)])
        type2 = self.getFieldType(df.schema, "struct2")
-       assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType())])
+       assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)])

    def test_with_struct_column2(self, setupLogging):
        column_count = 10

@@ -604,9 +604,9 @@ def test_with_struct_column2(self, setupLogging):
        df = df_spec.build()

        type1 = self.getFieldType(df.schema, "struct1")
-       assert type1 == StructType([StructField('code1', IntegerType()), StructField('code2', IntegerType())])
+       assert type1 == StructType([StructField('code1', IntegerType(), True), StructField('code2', IntegerType(), True)])
        type2 = self.getFieldType(df.schema, "struct2")
-       assert type2 == StructType([StructField('code5', DateType(), False), StructField('code6', StringType())])
+       assert type2 == StructType([StructField('code5', DateType(), False), StructField('code6', StringType(), False)])

    def test_with_json_struct_column(self, setupLogging):
        column_count = 10

@@ -680,13 +680,13 @@ def test_with_struct_column3(self, setupLogging):
        df = df_spec.build()

        type1 = self.getFieldType(df.schema, "struct1")
-       assert type1 == StructType([StructField('a', IntegerType()), StructField('b', IntegerType())])
+       assert type1 == StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)])
        type2 = self.getFieldType(df.schema, "struct2")
-       assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType())])
+       assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)])
        type3 = self.getFieldType(df.schema, "struct3")
        assert type3 == StructType(
-           [StructField('a', StructType([StructField('a', IntegerType()), StructField('b', IntegerType())]), False),
-            StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType())]),
+           [StructField('a', StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)]), False),
+            StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]),
                         False)])

    def test_with_struct_column4(self, setupLogging):

@@ -711,13 +711,13 @@ def test_with_struct_column4(self, setupLogging):
        df = df_spec.build()

        type1 = self.getFieldType(df.schema, "struct1")
-       assert type1 == StructType([StructField('a', IntegerType()), StructField('b', IntegerType())])
+       assert type1 == StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)])
        type2 = self.getFieldType(df.schema, "struct2")
-       assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType())])
+       assert type2 == StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)])
        type3 = self.getFieldType(df.schema, "struct3")
        assert type3 == StructType(
-           [StructField('a', StructType([StructField('a', IntegerType()), StructField('b', IntegerType())]), False),
-            StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType())]),
+           [StructField('a', StructType([StructField('a', IntegerType(), True), StructField('b', IntegerType(), True)]), False),
+            StructField('b', StructType([StructField('a', DateType(), False), StructField('b', StringType(), False)]),
                         False)])

    def test_with_struct_column_err1(self, setupLogging):
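Note: the expected schemas now spell out nullability explicitly. `StructField` defaults to `nullable=True`, so adding `True` is cosmetic, but the string fields are now expected to be non-nullable. A quick standalone check of how `StructField` equality treats nullability:

from pyspark.sql.types import StructField, IntegerType, StringType

# Defaults already match an explicit True...
assert StructField("a", IntegerType()) == StructField("a", IntegerType(), True)
# ...whereas flipping nullability changes equality, which is the substantive change here.
assert StructField("b", StringType(), False) != StructField("b", StringType())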

tests/test_serverless.py

Lines changed: 24 additions & 6 deletions
@@ -24,10 +24,28 @@ def serverlessSpark(self):

        oldSetMethod = sparkSession.conf.set
        oldGetMethod = sparkSession.conf.get
-       sparkSession.conf.set = MagicMock(
-           side_effect=ValueError("Setting value prohibited in simulated serverless env."))
-       sparkSession.conf.get = MagicMock(
-           side_effect=ValueError("Getting value prohibited in simulated serverless env."))
+       def mock_conf_set(*args, **kwargs):
+           raise ValueError("Setting value prohibited in simulated serverless env.")
+
+       def mock_conf_get(config_key, default=None):
+           # Allow internal PySpark configuration calls that are needed for basic operation
+           whitelisted_configs = {
+               'spark.sql.stackTracesInDataFrameContext': '1',
+               'spark.sql.execution.arrow.enabled': 'false',
+               'spark.sql.execution.arrow.pyspark.enabled': 'false',
+               'spark.python.sql.dataFrameDebugging.enabled': 'true',
+               'spark.sql.execution.arrow.maxRecordsPerBatch': '10000'
+           }
+           if config_key in whitelisted_configs:
+               try:
+                   return oldGetMethod(config_key, whitelisted_configs[config_key])
+               except:
+                   return whitelisted_configs[config_key]
+           else:
+               raise ValueError("Getting value prohibited in simulated serverless env.")
+
+       sparkSession.conf.set = MagicMock(side_effect=mock_conf_set)
+       sparkSession.conf.get = MagicMock(side_effect=mock_conf_get)

        yield sparkSession

@@ -59,7 +77,7 @@ def test_basic_data(self, serverlessSpark):
            )
        )

-       dfTestData = testDataSpec.build()
+       testDataSpec.build()

    @pytest.mark.parametrize("providerName, providerOptions", [
        ("basic/user", {"rows": 50, "partitions": 4, "random": False, "dummyValues": 0}),

@@ -72,4 +90,4 @@ def test_basic_user_table_retrieval(self, providerName, providerOptions, serverl
        """
        df = ds.build()

-       assert df.count() >= 0
+       assert df.count() >= 0
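Note: the fixture now routes `conf.get` through a delegating function via `MagicMock(side_effect=...)`, so whitelisted keys fall back to the real getter while everything else still raises. A stripped-down sketch of that delegation pattern, using a dummy config object rather than a real SparkSession:

from unittest.mock import MagicMock

class Config:
    def get(self, key, default=None):
        return {"spark.sql.shuffle.partitions": "200"}.get(key, default)

conf = Config()
old_get = conf.get                                # keep the original bound method
allowed = {"spark.sql.shuffle.partitions": "8"}   # illustrative whitelist

def mock_conf_get(key, default=None):
    if key in allowed:
        return old_get(key, allowed[key])         # fall through to the real getter
    raise ValueError("Getting value prohibited in simulated serverless env.")

conf.get = MagicMock(side_effect=mock_conf_get)
print(conf.get("spark.sql.shuffle.partitions"))   # "200", served by the real getter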
