Skip to content

Commit ef6e6af

Browse files
jasonlin0189impv, Jason, and Ethan_C_Lin
authored
Modify 'PatternMatch' function in analyzers.py and 'hasPattern' function in checks.py (#66)
* Add an implementation to the `hasPattern` function, because it had no body in version 1.0.0. * Modify the `PatternMatch` function, because the previous regex would not match anything. * Add test cases for `analyzers.PatternMatch` and `checks.hasPattern`. --------- Co-authored-by: Jason <[email protected]> Co-authored-by: Ethan_C_Lin <[email protected]>
1 parent aff4be6 commit ef6e6af

File tree

4 files changed

+40
-3
lines changed

4 files changed

+40
-3
lines changed

pydeequ/analyzers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ class PatternMatch(_AnalyzerObject):
664664

665665
def __init__(self, column, pattern_regex: str, *pattern_groupNames, where: str = None):
666666
self.column = column
667-
self.pattern_regex = (pattern_regex,)
667+
self.pattern_regex = pattern_regex
668668
if pattern_groupNames:
669669
raise NotImplementedError("pattern_groupNames have not been implemented yet.")
670670
self.pattern_groupNames = None
@@ -679,7 +679,7 @@ def _analyzer_jvm(self):
679679
"""
680680
return self._deequAnalyzers.PatternMatch(
681681
self.column,
682-
self._jvm.scala.util.matching.Regex(str(self.pattern_regex), None),
682+
self._jvm.scala.util.matching.Regex(self.pattern_regex, None),
683683
# TODO: revisit bc scala constructor does some weird implicit type casting from python str -> java list
684684
# if we don't cast it to str()
685685
self._jvm.scala.Option.apply(self.where),

pydeequ/checks.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,13 @@ def hasPattern(self, column, pattern, assertion=None, name=None, hint=None):
564564
:param str hint: A hint that states why a constraint could have failed.
565565
:return: hasPattern self: A Check object that runs the condition on the column.
566566
"""
567+
assertion_func = ScalaFunction1(self._spark_session.sparkContext._gateway, assertion) if assertion \
568+
else getattr(self._Check, "hasPattern$default$2")()
569+
name = self._jvm.scala.Option.apply(name)
570+
hint = self._jvm.scala.Option.apply(hint)
571+
pattern_regex = self._jvm.scala.util.matching.Regex(pattern, None)
572+
self._Check = self._Check.hasPattern(column, pattern_regex, assertion_func, name, hint)
573+
return self
567574

568575
def containsCreditCardNumber(self, column, assertion=None, hint=None):
569576
"""

tests/test_analyzers.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,8 +493,18 @@ def test_MutualInformation(self):
493493
def test_fail_MutualInformation(self):
494494
self.assertEqual(self.MutualInformation(["b", "d"]), [])
495495

496-
# TODO: Revisit when PatternMatch class is sorted out
497496
def test_PatternMatch(self):
497+
result = (
498+
self.AnalysisRunner.onData(self.df).addAnalyzer(PatternMatch(column="a", pattern_regex="ba(r|z)")).run()
499+
)
500+
result_df = AnalyzerContext.successMetricsAsDataFrame(self.spark, result)
501+
result_json = AnalyzerContext.successMetricsAsJson(self.spark, result)
502+
df_from_json = self.spark.read.json(self.sc.parallelize([result_json]))
503+
self.assertEqual(df_from_json.select("value").collect(), result_df.select("value").collect())
504+
self.assertEqual(result_df.select("value").collect(), [Row(value=0.6666666666666666)])
505+
506+
@pytest.mark.xfail(reason="@unittest.expectedFailure")
507+
def test_fail_PatternMatch(self):
498508
result = (
499509
self.AnalysisRunner.onData(self.df).addAnalyzer(PatternMatch(column="a", pattern_regex="ba(r|z)")).run()
500510
)

tests/test_checks.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,18 @@ def satisfies(self, columnCondition, constraintName, assertion=None, hint=None):
395395
df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
396396
return df.select("constraint_status").collect()
397397

398+
def hasPattern(self, column, pattern, assertion=None, name=None, hint=None):
399+
check = Check(self.spark, CheckLevel.Warning, "test hasPattern")
400+
result = (
401+
VerificationSuite(self.spark)
402+
.onData(self.df)
403+
.addCheck(check.hasPattern(column, pattern, assertion, name, hint))
404+
.run()
405+
)
406+
407+
df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
408+
return df.select("constraint_status").collect()
409+
398410
def isLessThanOrEqualTo(self, columnA, columnB, assertion=None, hint=None):
399411
check = Check(self.spark, CheckLevel.Warning, "test isLessThanOrEqualTo")
400412
result = (
@@ -1120,6 +1132,14 @@ def test_fail_satisfies(self):
11201132
)
11211133
self.assertEqual(self.satisfies('a = "zoo"', "find a", lambda x: x == 1), [Row(constraint_status="Success")])
11221134

1135+
def test_hasPattern(self):
1136+
self.assertEqual(self.hasPattern("ssn", "\d{3}\-\d{2}\-\d{4}", lambda x: x == 2 / 3), [Row(constraint_status="Success")])
1137+
1138+
@pytest.mark.xfail(reason="@unittest.expectedFailure")
1139+
def test_fail_hasPattern(self):
1140+
self.assertEqual(self.hasPattern("ssn", r"\d{3}\-\d{2}\-\d{4}", lambda x: x == 2 / 3), [Row(constraint_status="Failure")])
1141+
self.assertEqual(self.hasPattern("ssn", r"\d{3}\d{2}\d{4}", lambda x: x == 2 / 3), [Row(constraint_status="Failure")])
1142+
11231143
def test_isNonNegative(self):
11241144
self.assertEqual(self.isNonNegative("b"), [Row(constraint_status="Success")])
11251145
self.assertEqual(

0 commit comments

Comments
 (0)