Skip to content

Commit 309650c

Browse files
authored
Added AggregateMatch rule (#663)
* Implement AggregateMatch rule * Add edge case tests
1 parent 1b7a037 commit 309650c

File tree

6 files changed

+379
-3
lines changed

6 files changed

+379
-3
lines changed

src/main/scala/com/amazon/deequ/dqdl/execution/DQDLExecutor.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
package com.amazon.deequ.dqdl.execution
1818

19-
import com.amazon.deequ.dqdl.execution.executors.{ColumnNamesMatchPatternExecutor, CompositeRulesExecutor, DataFreshnessExecutor, DatasetMatchExecutor, DeequRulesExecutor, ReferentialIntegrityExecutor, RowCountMatchExecutor, UnsupportedRulesExecutor}
20-
import com.amazon.deequ.dqdl.model.{ColumnNamesMatchPatternExecutableRule, CompositeExecutableRule, DataFreshnessExecutableRule, DatasetMatchExecutableRule, DeequExecutableRule, ExecutableRule, Failed, ReferentialIntegrityExecutableRule, RowCountMatchExecutableRule, RuleOutcome, UnsupportedExecutableRule}
19+
import com.amazon.deequ.dqdl.execution.executors.{AggregateMatchExecutor, ColumnNamesMatchPatternExecutor, CompositeRulesExecutor, DataFreshnessExecutor, DatasetMatchExecutor, DeequRulesExecutor, ReferentialIntegrityExecutor, RowCountMatchExecutor, UnsupportedRulesExecutor}
20+
import com.amazon.deequ.dqdl.model.{AggregateMatchExecutableRule, ColumnNamesMatchPatternExecutableRule, CompositeExecutableRule, DataFreshnessExecutableRule, DatasetMatchExecutableRule, DeequExecutableRule, ExecutableRule, Failed, ReferentialIntegrityExecutableRule, RowCountMatchExecutableRule, RuleOutcome, UnsupportedExecutableRule}
2121
import org.apache.spark.sql.DataFrame
2222
import software.amazon.glue.dqdl.model.DQRule
2323

@@ -41,7 +41,8 @@ object DQDLExecutor {
4141
classOf[ReferentialIntegrityExecutableRule] -> ReferentialIntegrityExecutor,
4242
classOf[DataFreshnessExecutableRule] -> DataFreshnessExecutor,
4343
classOf[ColumnNamesMatchPatternExecutableRule] -> ColumnNamesMatchPatternExecutor,
44-
classOf[DatasetMatchExecutableRule] -> DatasetMatchExecutor
44+
classOf[DatasetMatchExecutableRule] -> DatasetMatchExecutor,
45+
classOf[AggregateMatchExecutableRule] -> AggregateMatchExecutor
4546
)
4647

4748
def executeRules(rules: Seq[ExecutableRule], df: DataFrame,
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/**
2+
* Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
5+
* use this file except in compliance with the License. A copy of the License
6+
* is located at
7+
*
8+
* http://aws.amazon.com/apache2.0/
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed on
11+
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*
15+
*/
16+
17+
package com.amazon.deequ.dqdl.execution.executors
18+
19+
import com.amazon.deequ.dqdl.execution.DQDLExecutor
20+
import com.amazon.deequ.dqdl.model.{AggregateMatchExecutableRule, AggregateOperation, Avg, Failed, Passed, RuleOutcome, Sum}
21+
import org.apache.spark.sql.DataFrame
22+
import org.apache.spark.sql.functions.{avg, sum}
23+
import org.apache.spark.sql.types.DoubleType
24+
import software.amazon.glue.dqdl.model.DQRule
25+
26+
import scala.util.{Failure, Success, Try}
27+
28+
/**
 * Executes AggregateMatch rules.
 *
 * For each rule, two aggregate expressions (sum or avg) are evaluated, each against the data
 * source named by its alias, and the rule's assertion is applied to the ratio of the first
 * aggregate to the second.
 */
object AggregateMatchExecutor extends DQDLExecutor.RuleExecutor[AggregateMatchExecutableRule] {

  // Alias under which the primary DataFrame is made available next to the additional sources.
  private val PrimaryAlias = "primary"

  /**
   * Evaluates every rule against the available data sources.
   *
   * @param rules                 rules to evaluate
   * @param df                    primary dataset; registered under the "primary" alias
   * @param additionalDataSources extra datasets keyed by alias; an entry keyed "primary" is
   *                              replaced by `df`
   * @return the outcome of each rule, keyed by the rule's DQRule
   */
  override def executeRules(rules: Seq[AggregateMatchExecutableRule], df: DataFrame,
                            additionalDataSources: Map[String, DataFrame] = Map.empty): Map[DQRule, RuleOutcome] = {
    val dataSources = additionalDataSources + (PrimaryAlias -> df)
    rules.map { rule =>
      rule.dqRule -> evaluateRule(rule, dataSources)
    }.toMap
  }

  /**
   * Evaluates both aggregates of a rule and checks the assertion against their ratio.
   *
   * If either aggregate cannot be evaluated the rule fails with that error message and no
   * metrics; otherwise the ratio is reported as a metric and the outcome is Passed or Failed
   * depending on the assertion.
   */
  private def evaluateRule(rule: AggregateMatchExecutableRule,
                           dataSources: Map[String, DataFrame]): RuleOutcome = {
    val ratioOrError = for {
      first <- evaluateAggregate(rule.firstAggregateOperation, dataSources)
      second <- evaluateAggregate(rule.secondAggregateOperation, dataSources)
    } yield computeRatio(first, second)

    ratioOrError match {
      case Right(ratio) =>
        // Build the metrics map from the Option directly instead of calling Option.get.
        val metrics = rule.evaluatedMetricName.map(_ -> ratio).toMap
        if (rule.assertion(ratio)) {
          RuleOutcome(rule.dqRule, Passed, None, metrics)
        } else {
          RuleOutcome(rule.dqRule, Failed,
            Some(s"Value: $ratio does not meet the constraint requirement."), metrics)
        }
      case Left(errorMsg) =>
        RuleOutcome(rule.dqRule, Failed, Some(errorMsg))
    }
  }

  /**
   * Computes a single aggregate as a Double.
   *
   * @return Right(value) on success; Left(message) when the data source alias is unknown or
   *         when Spark fails while evaluating the aggregate (for example on a missing column)
   */
  private def evaluateAggregate(op: AggregateOperation,
                                dataSources: Map[String, DataFrame]): Either[String, Double] = {
    dataSources.get(op.dataSourceAlias) match {
      case Some(ds) =>
        val aggregate = op match {
          case Avg(_, column) => avg(column).cast(DoubleType)
          case Sum(_, column) => sum(column).cast(DoubleType)
        }
        Try {
          val row = ds.select(aggregate).collect().head
          // sum/avg produce SQL NULL when there are no non-null input values. Map that case to
          // 0.0 explicitly rather than relying on the implicit unboxing of a null Double.
          if (row.isNullAt(0)) 0.0 else row.getDouble(0)
        } match {
          case Success(value) => Right(value)
          case Failure(ex) => Left(s"Exception: ${ex.getClass.getName}")
        }
      case None =>
        Left(s"${op.dataSourceAlias} not found in additional sources")
    }
  }

  /**
   * Ratio of the first aggregate to the second, with the zero-denominator cases pinned:
   * both zero counts as a perfect match (1.0); only the second being zero yields 0.0.
   */
  private def computeRatio(first: Double, second: Double): Double = {
    if (first == 0 && second == 0) 1.0 else if (second == 0) 0.0 else first / second
  }
}

src/main/scala/com/amazon/deequ/dqdl/model/ExecutableRule.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,26 @@ case class CompositeExecutableRule(dqRule: DQRule,
100100
override val evaluatedMetricName: Option[String] = None
101101
}
102102

103+
/**
 * An aggregate applied to a single column of a single data source.
 */
sealed trait AggregateOperation {
  /** Alias of the data source that contains the column. */
  def dataSourceAlias: String

  /** Name of the column the aggregate is applied to. */
  def column: String
}

/** Sum aggregate over `column` of the data source identified by `dataSourceAlias`. */
case class Sum(dataSourceAlias: String, column: String) extends AggregateOperation

/** Average aggregate over `column` of the data source identified by `dataSourceAlias`. */
case class Avg(dataSourceAlias: String, column: String) extends AggregateOperation
110+
111+
/**
 * Executable form of an AggregateMatch rule.
 *
 * @param dqRule                   the DQDL rule this was translated from
 * @param firstAggregateOperation  aggregate used as the numerator side of the comparison
 * @param secondAggregateOperation aggregate used as the denominator side of the comparison
 * @param assertion                predicate applied to the computed value
 */
case class AggregateMatchExecutableRule(dqRule: DQRule,
                                        firstAggregateOperation: AggregateOperation,
                                        secondAggregateOperation: AggregateOperation,
                                        assertion: Double => Boolean) extends ExecutableRule {
  // Metric instance is the column name when both sides use the same column,
  // otherwise both column names joined by a comma.
  override val evaluatedMetricName: Option[String] = {
    val columns = Seq(firstAggregateOperation.column, secondAggregateOperation.column).distinct
    Some(s"Column.${columns.mkString(",")}.AggregateMatch")
  }
}
122+
103123
case class DeequMetricMapping(entity: String,
104124
instance: String,
105125
name: String,

src/main/scala/com/amazon/deequ/dqdl/translation/DQDLRuleTranslator.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ import com.amazon.deequ.dqdl.translation.rules.ReferentialIntegrityRule
3939
import com.amazon.deequ.dqdl.translation.rules.DatasetMatchRule
4040
import com.amazon.deequ.dqdl.translation.rules.DataFreshnessRule
4141
import com.amazon.deequ.dqdl.translation.rules.ColumnNamesMatchPatternRule
42+
import com.amazon.deequ.dqdl.translation.rules.AggregateMatchRule
4243
import software.amazon.glue.dqdl.model.DQRule
4344
import software.amazon.glue.dqdl.model.DQRuleset
4445

@@ -102,6 +103,7 @@ object DQDLRuleTranslator {
102103
}
103104
case "RowCountMatch" => RowCountMatchRule.toExecutableRule(rule)
104105
case "ColumnNamesMatchPattern" => ColumnNamesMatchPatternRule.toExecutableRule(rule)
106+
case "AggregateMatch" => AggregateMatchRule.toExecutableRule(rule)
105107
case "ReferentialIntegrity" =>
106108
ReferentialIntegrityRule.toExecutableRule(rule) match {
107109
case Right(executableRule) => executableRule
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/**
2+
* Copyright 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
5+
* use this file except in compliance with the License. A copy of the License
6+
* is located at
7+
*
8+
* http://aws.amazon.com/apache2.0/
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed on
11+
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*
15+
*/
16+
17+
package com.amazon.deequ.dqdl.translation.rules
18+
19+
import com.amazon.deequ.dqdl.execution.DefaultOperandEvaluator
20+
import com.amazon.deequ.dqdl.model.{AggregateMatchExecutableRule, AggregateOperation, Avg, Sum, UnsupportedExecutableRule}
21+
import com.amazon.deequ.dqdl.model.ExecutableRule
22+
import software.amazon.glue.dqdl.model.DQRule
23+
import software.amazon.glue.dqdl.model.condition.number.NumberBasedCondition
24+
25+
import scala.collection.JavaConverters._
26+
import scala.util.matching.Regex
27+
28+
/**
 * Translates a DQDL AggregateMatch rule into an [[AggregateMatchExecutableRule]].
 */
object AggregateMatchRule {

  // Data source alias assigned to columns written without a qualifier.
  private val PrimaryAlias = "primary"
  // sum(colA), sum(reference-1.colA), avg(colB), avg(db.table.colB) etc.
  private val operationRegex: Regex = "(sum|avg|SUM|AVG)\\((.*)\\)".r

  /**
   * Builds the executable rule from the rule's "AggregateExpression1" and
   * "AggregateExpression2" parameters and its number-based condition.
   *
   * @param rule the DQDL rule to translate
   * @return an [[AggregateMatchExecutableRule]], or an [[UnsupportedExecutableRule]] when a
   *         parameter is missing, the condition is not number-based, or either expression
   *         cannot be parsed
   */
  def toExecutableRule(rule: DQRule): ExecutableRule = {
    // Looks a parameter up without throwing when the map, the key or the value is absent.
    def parameter(name: String): Option[String] =
      Option(rule.getParameters).flatMap(_.asScala.get(name)).flatMap(Option(_))

    val translated = for {
      aggregateExpression1 <- parameter("AggregateExpression1")
      aggregateExpression2 <- parameter("AggregateExpression2")
      // A type pattern instead of asInstanceOf: a null or non-numeric condition gives None
      // rather than a ClassCastException here or a NullPointerException at evaluation time.
      condition <- Option(rule.getCondition).collect { case c: NumberBasedCondition => c }
      aggOp1 <- parseAggregateOperation(aggregateExpression1)
      aggOp2 <- parseAggregateOperation(aggregateExpression2)
    } yield {
      val assertion: Double => Boolean =
        (d: Double) => condition.evaluate(d, rule, DefaultOperandEvaluator)
      AggregateMatchExecutableRule(rule, aggOp1, aggOp2, assertion)
    }

    translated.getOrElse(UnsupportedExecutableRule(rule, Some("Unsupported Rule")))
  }

  /**
   * Splits a possibly qualified column reference at its last dot.
   *
   * @return (column name with double quotes removed, data source alias); the alias is
   *         "primary" when the reference contains no dot
   */
  private def parseCol(c: String): (String, String) = {
    val splitAt = c.lastIndexOf(".")
    if (splitAt == -1) {
      (c.replace("\"", ""), PrimaryAlias)
    } else {
      val column = c.substring(splitAt + 1)
      val datasourceAlias = c.substring(0, splitAt)
      (column.replace("\"", ""), datasourceAlias)
    }
  }

  /**
   * Parses an expression such as `sum(colA)` or `avg(reference.colB)`.
   *
   * @return the parsed operation, or None when the text does not match the expected form
   */
  private def parseAggregateOperation(op: String): Option[AggregateOperation] = {
    op match {
      case operationRegex(aggOp, col) =>
        val (aggCol, refAlias) = parseCol(col)
        aggOp.toLowerCase match {
          case "avg" => Some(Avg(refAlias, aggCol))
          case "sum" => Some(Sum(refAlias, aggCol))
          case _ => None
        }
      case _ => None
    }
  }
}

0 commit comments

Comments
 (0)