Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ trait Analyzer[S <: State[_], +M <: Metric[_]] extends Serializable {
*/
def computeMetricFrom(state: Option[S]): M

/**
* Returns the columns this analyzer reads from the data, if known statically.
* Returns Some(columns) when all referenced columns can be determined,
* or None when the analyzer may reference arbitrary columns (e.g. free-form SQL predicates).
* Used by AnalysisRunner to enable column pruning for V2 DataSource connectors like Iceberg.
*/
def columnsReferenced(): Option[Set[String]] = None

/**
* A set of assertions that must hold on the schema of the data frame
* @return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,7 @@ case class ApproxCountDistinct(column: String, where: Option[String] = None)
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,7 @@ case class ApproxQuantile(
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,6 @@ case class ApproxQuantiles(column: String, quantiles: Seq[Double], relativeError
override def preconditions: Seq[StructType => Unit] = {
PARAM_CHECKS :: hasColumn(column) :: isNumeric(column) :: Nil
}

override def columnsReferenced(): Option[Set[String]] = Some(Set(column))
}
2 changes: 2 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,6 @@ case class ColumnCount() extends Analyzer[NumMatches, DoubleMetric] {
override private[deequ] def toFailureMetric(failure: Exception): DoubleMetric = {
Analyzers.metricFromFailure(failure, name, instance, entity)
}

override def columnsReferenced(): Option[Set[String]] = Some(Set.empty)
}
2 changes: 2 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/ColumnExists.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,6 @@ case class ColumnExists(column: String) extends Analyzer[ColumnExistsState, Doub
override private[deequ] def toFailureMetric(failure: Exception): DoubleMetric = {
Analyzers.metricFromFailure(failure, name, instance, entity)
}

override def columnsReferenced(): Option[Set[String]] = Some(Set.empty)
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Completeness.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ case class Completeness(column: String, where: Option[String] = None,

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))

@VisibleForTesting // required by some tests that compare analyzer results to an expected state
private[deequ] def criterion: Column = conditionalSelection(column, where).isNotNull

Expand Down
4 changes: 4 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Compliance.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,8 @@ case class Compliance(instance: String,

override protected def additionalPreconditions(): Seq[StructType => Unit] =
columns.map(hasColumn)

// Compliance uses free-form SQL predicates that can reference arbitrary columns,
// so we cannot safely determine which columns are needed.
override def columnsReferenced(): Option[Set[String]] = None
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Correlation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,7 @@ case class Correlation(
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(firstColumn, secondColumn))
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/DataType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,4 +183,7 @@ case class DataType(
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/ExactQuantile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ with FilterableAnalyzer {

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))

@VisibleForTesting
private def criterion: Column = conditionalSelection(column, where).cast(DoubleType)
}
2 changes: 2 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/KLLSketch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ case class KLLSketch(
override def preconditions(): Seq[StructType => Unit] = {
PARAM_CHECK :: hasColumn(column) :: isNumeric(column) :: Nil
}

override def columnsReferenced(): Option[Set[String]] = Some(Set(column))
}

object KLLSketch {
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))

private[deequ] def criterion: Column = {
val isNullCheck = col(column).isNull
val colLength = length(col(column)).cast(DoubleType)
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))

@VisibleForTesting
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where)
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Mean.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,7 @@ case class Mean(column: String, where: Option[String] = None)
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/MinLength.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))

private[deequ] def criterion: Column = {
val isNullCheck = col(column).isNull
val colLength = length(col(column)).cast(DoubleType)
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))

@VisibleForTesting
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where)
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/PatternMatch.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ case class PatternMatch(column: String, pattern: Regex, where: Option[String] =

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))

override protected def additionalPreconditions(): Seq[StructType => Unit] = {
hasColumn(column) :: isString(column) :: Nil
}
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/RatioOfSums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,7 @@ case class RatioOfSums(
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(numerator, denominator))
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Size.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,7 @@ case class Size(where: Option[String] = None)
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set.empty)
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,7 @@ case class StandardDeviation(column: String, where: Option[String] = None)
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))
}
3 changes: 3 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Sum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,7 @@ case class Sum(column: String, where: Option[String] = None)
}

override def filterCondition: Option[String] = where

override def columnsReferenced(): Option[Set[String]] =
if (where.isDefined) None else Some(Set(column))
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import com.amazon.deequ.io.DfsUtils
import com.amazon.deequ.metrics.{DoubleMetric, Metric}
import com.amazon.deequ.repository.{MetricsRepository, ResultKey}
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.storage.StorageLevel
Expand Down Expand Up @@ -360,7 +361,9 @@ object AnalysisRunner {
val offsets = shareableAnalyzers.scanLeft(0) { case (current, analyzer) =>
current + analyzer.aggregationFunctions().length
}
val results = data.agg(aggregations.head, aggregations.tail: _*).collect().head

val prunedData = pruneColumns(data, shareableAnalyzers)
val results = prunedData.agg(aggregations.head, aggregations.tail: _*).collect().head
shareableAnalyzers.zip(offsets).map { case (analyzer, offset) =>
analyzer ->
successOrFailureMetricFrom(analyzer, results, offset, aggregateWith, saveStatesTo)
Expand All @@ -382,6 +385,35 @@ object AnalysisRunner {
sharedResults ++ AnalyzerContext(otherMetrics)
}

/**
* Attempts to select only the columns needed by the given analyzers.
* This enables column pruning for V2 DataSource connectors (e.g. Iceberg, Delta Lake)
* which make scan-planning decisions before Spark's optimizer can simplify the plan.
*
* Falls back to the original DataFrame if any analyzer cannot statically declare its columns
* (e.g. analyzers with free-form SQL predicates or WHERE clauses).
*/
private[this] def pruneColumns(
data: DataFrame,
analyzers: Seq[Analyzer[_, _]])
: DataFrame = {

val allColumns = analyzers.map(_.columnsReferenced())

if (allColumns.exists(_.isEmpty)) {
// At least one analyzer cannot declare its columns; skip pruning
data
} else {
val neededColumns = allColumns.flatMap(_.get).distinct
if (neededColumns.isEmpty) {
// All analyzers are dataset-level (e.g. Size), no column selection needed
data
} else {
data.select(neededColumns.map(col): _*)
}
}
}

/** Compute scan-shareable analyzer metric from aggregation result, mapping generic exceptions
* to a failure metric */
private def successOrFailureMetricFrom(
Expand Down
51 changes: 51 additions & 0 deletions src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,57 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with
assert(testVal.value.isSuccess)
assert(testVal.value.toOption.get.isInfinite)
}

"return correct columnsReferenced for analyzers without where clauses" in {
assert(Completeness("col1").columnsReferenced() === Some(Set("col1")))
assert(Mean("col1").columnsReferenced() === Some(Set("col1")))
assert(Maximum("col1").columnsReferenced() === Some(Set("col1")))
assert(Minimum("col1").columnsReferenced() === Some(Set("col1")))
assert(Sum("col1").columnsReferenced() === Some(Set("col1")))
assert(StandardDeviation("col1").columnsReferenced() === Some(Set("col1")))
assert(ApproxCountDistinct("col1").columnsReferenced() === Some(Set("col1")))
assert(DataType("col1").columnsReferenced() === Some(Set("col1")))
assert(PatternMatch("col1", ".*".r).columnsReferenced() === Some(Set("col1")))
assert(MaxLength("col1").columnsReferenced() === Some(Set("col1")))
assert(MinLength("col1").columnsReferenced() === Some(Set("col1")))
assert(ExactQuantile("col1", 0.5).columnsReferenced() === Some(Set("col1")))
assert(ApproxQuantile("col1", 0.5).columnsReferenced() === Some(Set("col1")))
assert(ApproxQuantiles("col1", Seq(0.25, 0.75)).columnsReferenced() === Some(Set("col1")))
assert(Correlation("col1", "col2").columnsReferenced() === Some(Set("col1", "col2")))
assert(RatioOfSums("col1", "col2").columnsReferenced() === Some(Set("col1", "col2")))
assert(Size().columnsReferenced() === Some(Set.empty))
assert(ColumnCount().columnsReferenced() === Some(Set.empty))
assert(ColumnExists("col1").columnsReferenced() === Some(Set.empty))
assert(KLLSketch("col1").columnsReferenced() === Some(Set("col1")))
}

"return None for columnsReferenced when where clause is present" in {
assert(Completeness("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(Mean("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(Maximum("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(Minimum("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(Sum("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(StandardDeviation("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(ApproxCountDistinct("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(DataType("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(PatternMatch("col1", ".*".r, Some("col2 > 0")).columnsReferenced() === None)
assert(MaxLength("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(MinLength("col1", Some("col2 > 0")).columnsReferenced() === None)
assert(ExactQuantile("col1", 0.5, Some("col2 > 0")).columnsReferenced() === None)
assert(ApproxQuantile("col1", 0.5, where = Some("col2 > 0")).columnsReferenced() === None)
assert(Correlation("col1", "col2", Some("col1 > 0")).columnsReferenced() === None)
assert(RatioOfSums("col1", "col2", Some("col1 > 0")).columnsReferenced() === None)
assert(Size(Some("col1 > 0")).columnsReferenced() === None)
}

"return None for columnsReferenced for Compliance (free-form SQL)" in {
assert(Compliance("test", "col1 > 0").columnsReferenced() === None)
assert(Compliance("test", "col1 > 0", columns = List("col1")).columnsReferenced() === None)
}

"return None for columnsReferenced for CustomSql" in {
assert(CustomSql("SELECT COUNT(*) FROM table").columnsReferenced() === None)
}
}
}

Expand Down
Loading
Loading