Skip to content

Commit 59274a5

Browse files
authored
Updated Spark 3.3 dependency (#196)
This commit updates the Spark 3.3 dependency of Deequ. There are some breaking changes to the Scala APIs, from a Py4J perspective. In order to work around that, we use the Spark version to switch between the updated API and the old API. This is not sustainable and will be revisited in a future PR, or via a different release mechanism. The issue is that we have multiple branches for multiple Spark versions in Deequ, but only one branch in PyDeequ. The changes were verified by running the tests in Docker against Spark version 3.3. The Dockerfile was also updated so that it copies over the pyproject.toml file and installs dependencies in a separate layer, before the code is copied. This allows for fast iteration on the code, without the need to install dependencies every time the Docker image is built.
1 parent 4bb727b commit 59274a5

File tree

5 files changed

+87
-19
lines changed

5 files changed

+87
-19
lines changed

Dockerfile

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ RUN pip3 --version
1616
RUN java -version
1717
RUN pip install poetry==1.7.1
1818

19-
COPY . /python-deequ
19+
RUN mkdir python-deequ
20+
COPY pyproject.toml /python-deequ
21+
COPY poetry.lock /python-deequ
2022
WORKDIR python-deequ
2123

22-
RUN poetry lock --no-update
23-
RUN poetry install
24-
RUN poetry add pyspark==3.3
24+
RUN poetry install -vvv
25+
RUN poetry add pyspark==3.3 -vvv
2526

2627
ENV SPARK_VERSION=3.3
28+
COPY . /python-deequ
2729
CMD poetry run python -m pytest -s tests

pydeequ/analyzers.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pydeequ.repository import MetricsRepository, ResultKey
1111
from enum import Enum
1212
from pydeequ.scala_utils import to_scala_seq
13-
13+
from pydeequ.configs import SPARK_VERSION
1414

1515
class _AnalyzerObject:
1616
"""
@@ -303,7 +303,19 @@ def _analyzer_jvm(self):
303303
304304
:return self
305305
"""
306-
return self._deequAnalyzers.Compliance(self.instance, self.predicate, self._jvm.scala.Option.apply(self.where))
306+
if SPARK_VERSION == "3.3":
307+
return self._deequAnalyzers.Compliance(
308+
self.instance,
309+
self.predicate,
310+
self._jvm.scala.Option.apply(self.where),
311+
self._jvm.scala.collection.Seq.empty()
312+
)
313+
else:
314+
return self._deequAnalyzers.Compliance(
315+
self.instance,
316+
self.predicate,
317+
self._jvm.scala.Option.apply(self.where)
318+
)
307319

308320

309321
class Correlation(_AnalyzerObject):
@@ -457,12 +469,22 @@ def _analyzer_jvm(self):
457469
"""
458470
if not self.maxDetailBins:
459471
self.maxDetailBins = getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$3")()
460-
return self._deequAnalyzers.Histogram(
461-
self.column,
462-
self._jvm.scala.Option.apply(self.binningUdf),
463-
self.maxDetailBins,
464-
self._jvm.scala.Option.apply(self.where),
465-
)
472+
if SPARK_VERSION == "3.3":
473+
return self._deequAnalyzers.Histogram(
474+
self.column,
475+
self._jvm.scala.Option.apply(self.binningUdf),
476+
self.maxDetailBins,
477+
self._jvm.scala.Option.apply(self.where),
478+
getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$5")(),
479+
getattr(self._jvm.com.amazon.deequ.analyzers.Histogram, "apply$default$6")()
480+
)
481+
else:
482+
return self._deequAnalyzers.Histogram(
483+
self.column,
484+
self._jvm.scala.Option.apply(self.binningUdf),
485+
self.maxDetailBins,
486+
self._jvm.scala.Option.apply(self.where)
487+
)
466488

467489

468490
class KLLParameters:
@@ -553,7 +575,17 @@ def _analyzer_jvm(self):
553575
554576
:return self
555577
"""
556-
return self._deequAnalyzers.MaxLength(self.column, self._jvm.scala.Option.apply(self.where))
578+
if SPARK_VERSION == "3.3":
579+
return self._deequAnalyzers.MaxLength(
580+
self.column,
581+
self._jvm.scala.Option.apply(self.where),
582+
self._jvm.scala.Option.apply(None)
583+
)
584+
else:
585+
return self._deequAnalyzers.MaxLength(
586+
self.column,
587+
self._jvm.scala.Option.apply(self.where)
588+
)
557589

558590

559591
class Mean(_AnalyzerObject):
@@ -619,7 +651,17 @@ def _analyzer_jvm(self):
619651
620652
:return self
621653
"""
622-
return self._deequAnalyzers.MinLength(self.column, self._jvm.scala.Option.apply(self.where))
654+
if SPARK_VERSION == "3.3":
655+
return self._deequAnalyzers.MinLength(
656+
self.column,
657+
self._jvm.scala.Option.apply(self.where),
658+
self._jvm.scala.Option.apply(None)
659+
)
660+
else:
661+
return self._deequAnalyzers.MinLength(
662+
self.column,
663+
self._jvm.scala.Option.apply(self.where)
664+
)
623665

624666

625667
class MutualInformation(_AnalyzerObject):

pydeequ/checks.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from pydeequ.check_functions import is_one
88
from pydeequ.scala_utils import ScalaFunction1, to_scala_seq
9-
9+
from pydeequ.configs import SPARK_VERSION
1010

1111
# TODO implement custom assertions
1212
# TODO implement all methods without outside class dependencies
@@ -418,7 +418,11 @@ def hasMinLength(self, column, assertion, hint=None):
418418
"""
419419
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
420420
hint = self._jvm.scala.Option.apply(hint)
421-
self._Check = self._Check.hasMinLength(column, assertion_func, hint)
421+
if SPARK_VERSION == "3.3":
422+
self._Check = self._Check.hasMinLength(column, assertion_func, hint, self._jvm.scala.Option.apply(None))
423+
else:
424+
self._Check = self._Check.hasMinLength(column, assertion_func, hint)
425+
422426
return self
423427

424428
def hasMaxLength(self, column, assertion, hint=None):
@@ -433,7 +437,10 @@ def hasMaxLength(self, column, assertion, hint=None):
433437
"""
434438
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
435439
hint = self._jvm.scala.Option.apply(hint)
436-
self._Check = self._Check.hasMaxLength(column, assertion_func, hint)
440+
if SPARK_VERSION == "3.3":
441+
self._Check = self._Check.hasMaxLength(column, assertion_func, hint, self._jvm.scala.Option.apply(None))
442+
else:
443+
self._Check = self._Check.hasMaxLength(column, assertion_func, hint)
437444
return self
438445

439446
def hasMin(self, column, assertion, hint=None):
@@ -558,7 +565,21 @@ def satisfies(self, columnCondition, constraintName, assertion=None, hint=None):
558565
else getattr(self._Check, "satisfies$default$3")()
559566
)
560567
hint = self._jvm.scala.Option.apply(hint)
561-
self._Check = self._Check.satisfies(columnCondition, constraintName, assertion_func, hint)
568+
if SPARK_VERSION == "3.3":
569+
self._Check = self._Check.satisfies(
570+
columnCondition,
571+
constraintName,
572+
assertion_func,
573+
hint,
574+
self._jvm.scala.collection.Seq.empty()
575+
)
576+
else:
577+
self._Check = self._Check.satisfies(
578+
columnCondition,
579+
constraintName,
580+
assertion_func,
581+
hint
582+
)
562583
return self
563584

564585
def hasPattern(self, column, pattern, assertion=None, name=None, hint=None):

pydeequ/configs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
SPARK_TO_DEEQU_COORD_MAPPING = {
8-
"3.3": "com.amazon.deequ:deequ:2.0.3-spark-3.3",
8+
"3.3": "com.amazon.deequ:deequ:2.0.4-spark-3.3",
99
"3.2": "com.amazon.deequ:deequ:2.0.1-spark-3.2",
1010
"3.1": "com.amazon.deequ:deequ:2.0.0-spark-3.1",
1111
"3.0": "com.amazon.deequ:deequ:1.2.2-spark-3.0",
@@ -40,5 +40,6 @@ def _get_deequ_maven_config():
4040
)
4141

4242

43+
SPARK_VERSION = _get_spark_version()
4344
DEEQU_MAVEN_COORD = _get_deequ_maven_config()
4445
IS_DEEQU_V1 = re.search("com\.amazon\.deequ\:deequ\:1.*", DEEQU_MAVEN_COORD) is not None

pydeequ/profiles.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ def __init__(self, spark_session: SparkSession):
241241
self._profiles = []
242242
self.columnProfileClasses = {
243243
"StandardColumnProfile": StandardColumnProfile,
244+
"StringColumnProfile": StandardColumnProfile,
244245
"NumericColumnProfile": NumericColumnProfile,
246+
245247
}
246248

247249
def _columnProfilesFromColumnRunBuilderRun(self, run):

0 commit comments

Comments (0)