Skip to content

Commit e74e974

Browse files
authored
fix: add assertion and hints for isContainedIn and hasPattern; Add is_one and use as default assertion (#157)
1 parent 04e2634 commit e74e974

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

pydeequ/check_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def is_one(x):
    """Default assertion for checks: return True iff the metric *x* equals 1.

    Used as the fallback assertion for ``hasPattern`` and ``isContainedIn``
    when the caller supplies none (a fully-satisfied constraint yields a
    metric of exactly 1).

    :param x: numeric metric value (int or float) produced by a check.
    :return: bool -- True when ``x`` equals 1.
    """
    # Original wrote ``x == 1 / 1`` (i.e. x == 1.0); ``x == 1`` is
    # equivalent for all numeric x and states the intent directly.
    return x == 1

pydeequ/checks.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
from pyspark.sql import SparkSession
55

6+
from pydeequ.check_functions import is_one
67
from pydeequ.scala_utils import ScalaFunction1, to_scala_seq
78

9+
810
# TODO implement custom assertions
911
# TODO implement all methods without outside class dependencies
1012
# TODO Integration with Constraints
@@ -564,8 +566,10 @@ def hasPattern(self, column, pattern, assertion=None, name=None, hint=None):
564566
:param str hint: A hint that states why a constraint could have failed.
565567
:return: hasPattern self: A Check object that runs the condition on the column.
566568
"""
567-
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) if assertion \
568-
else getattr(self._Check, "hasPattern$default$2")()
569+
if not assertion:
570+
assertion = is_one
571+
572+
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
569573
name = self._jvm.scala.Option.apply(name)
570574
hint = self._jvm.scala.Option.apply(hint)
571575
pattern_regex = self._jvm.scala.util.matching.Regex(pattern, None)
@@ -779,19 +783,25 @@ def isGreaterThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None):
779783
self._Check = self._Check.isGreaterThanOrEqualTo(columnA, columnB, assertion_func, hint)
780784
return self
781785

782-
def isContainedIn(self, column, allowed_values):
786+
def isContainedIn(self, column, allowed_values, assertion=None, hint=None):
783787
"""
784788
Asserts that every non-null value in a column is contained in a set of predefined values
785789
786790
:param str column: Column in DataFrame to run the assertion on.
787791
:param list[str] allowed_values: A function that accepts allowed values for the column.
792+
:param lambda assertion: A function that accepts an int or float parameter.
788793
:param str hint: A hint that states why a constraint could have failed.
789794
:return: isContainedIn self: A Check object that runs the assertion on the columns.
790795
"""
791796
arr = self._spark_session.sparkContext._gateway.new_array(self._jvm.java.lang.String, len(allowed_values))
792797
for i in range(len(allowed_values)):
793798
arr[i] = allowed_values[i]
794-
self._Check = self._Check.isContainedIn(column, arr)
799+
800+
if not assertion:
801+
assertion = is_one
802+
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion)
803+
hint = self._jvm.scala.Option.apply(hint)
804+
self._Check = self._Check.isContainedIn(column, arr, assertion_func, hint)
795805
return self
796806

797807
def evaluate(self, context):

tests/test_checks.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -431,10 +431,12 @@ def isGreaterThan(self, columnA, columnB, assertion=None, hint=None):
431431
df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
432432
return df.select("constraint_status").collect()
433433

434-
def isContainedIn(self, column, allowed_values):
434+
def isContainedIn(self, column, allowed_values, assertion=None, hint=None):
435435
check = Check(self.spark, CheckLevel.Warning, "test isContainedIn")
436436
result = (
437-
VerificationSuite(self.spark).onData(self.df).addCheck(check.isContainedIn(column, allowed_values)).run()
437+
VerificationSuite(self.spark).onData(self.df).addCheck(
438+
check.isContainedIn(column, allowed_values, assertion=assertion, hint=hint)
439+
).run()
438440
)
439441

440442
df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
@@ -1134,6 +1136,11 @@ def test_fail_satisfies(self):
11341136

11351137
def test_hasPattern(self):
11361138
self.assertEqual(self.hasPattern("ssn", "\d{3}\-\d{2}\-\d{4}", lambda x: x == 2 / 3), [Row(constraint_status="Success")])
1139+
# Default assertion is 1, thus failure
1140+
self.assertEqual(self.hasPattern("ssn", "\d{3}\-\d{2}\-\d{4}"), [Row(constraint_status="Failure")])
1141+
self.assertEqual(
1142+
self.hasPattern("ssn", "\d{3}\-\d{2}\-\d{4}", lambda x: x == 2 / 3, hint="it be should be above 0.66"),
1143+
[Row(constraint_status="Success")])
11371144

11381145
@pytest.mark.xfail(reason="@unittest.expectedFailure")
11391146
def test_fail_hasPattern(self):
@@ -1206,6 +1213,12 @@ def test_fail_isGreaterThan(self):
12061213
self.assertEqual(self.isGreaterThan("h", "f", lambda x: x == 1), [Row(constraint_status="Success")])
12071214

12081215
def test_isContainedIn(self):
1216+
# test all variants for assertion and hint
1217+
self.assertEqual(
1218+
self.isContainedIn("a", ["foo", "bar", "baz"], lambda x: x == 1), [Row(constraint_status="Success")])
1219+
self.assertEqual(
1220+
self.isContainedIn("a", ["foo", "bar", "baz"], lambda x: x == 1, hint="it should be 1"),
1221+
[Row(constraint_status="Success")])
12091222
self.assertEqual(self.isContainedIn("a", ["foo", "bar", "baz"]), [Row(constraint_status="Success")])
12101223
# A none value makes the test still pass
12111224
self.assertEqual(self.isContainedIn("c", ["5", "6"]), [Row(constraint_status="Success")])

0 commit comments

Comments
 (0)