// --- src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala (patched section, reconstructed) ---
// New imports required by this section:
//   import java.math.BigDecimal
//   import com.amazon.deequ.metrics.{BigDecimalMetric, DateTimeMetric, DoubleMetric, Entity, Metric}

/** A state which produces a BigDecimal-valued metric */
trait BigDecimalValuedState[S <: BigDecimalValuedState[S]] extends State[S] {
  def metricValue(): BigDecimal
}

/** A scan-shareable analyzer that produces a DateTimeMetric */
abstract class TimestampScanShareableAnalyzer[S <: DateTimeValuedState[_]](
    name: String,
    instance: String,
    entity: Entity.Value = Entity.Column)
  extends ScanShareableAnalyzer[S, DateTimeMetric] {

  // A missing state (e.g. all rows filtered out) becomes a failed metric
  // instead of an exception, matching StandardScanShareableAnalyzer.
  override def computeMetricFrom(state: Option[S]): DateTimeMetric = {
    state match {
      case Some(theState) =>
        DateTimeMetric(entity, name, instance, Success(theState.metricValue()))
      case _ =>
        DateTimeMetric(entity, name, instance, Failure(
          MetricCalculationException.wrapIfNecessary(emptyStateException(this))))
    }
  }

  override private[deequ] def toFailureMetric(exception: Exception): DateTimeMetric = {
    DateTimeMetric(entity, name, instance, Failure(
      MetricCalculationException.wrapIfNecessary(exception)))
  }

  override def preconditions: Seq[StructType => Unit] = {
    additionalPreconditions() ++ super.preconditions
  }

  protected def additionalPreconditions(): Seq[StructType => Unit] = {
    Seq.empty
  }
}

/** A scan-shareable analyzer that produces a BigDecimalMetric */
abstract class BigDecimalScanShareableAnalyzer[S <: BigDecimalValuedState[_]](
    name: String,
    instance: String,
    entity: Entity.Value = Entity.Column)
  extends ScanShareableAnalyzer[S, BigDecimalMetric] {

  override def computeMetricFrom(state: Option[S]): BigDecimalMetric = {
    state match {
      case Some(theState) =>
        BigDecimalMetric(entity, name, instance, Success(theState.metricValue()))
      case _ =>
        BigDecimalMetric(entity, name, instance, Failure(
          MetricCalculationException.wrapIfNecessary(emptyStateException(this))))
    }
  }

  override private[deequ] def toFailureMetric(exception: Exception): BigDecimalMetric = {
    BigDecimalMetric(entity, name, instance, Failure(
      MetricCalculationException.wrapIfNecessary(exception)))
  }

  override def preconditions: Seq[StructType => Unit] = {
    additionalPreconditions() ++ super.preconditions
  }

  protected def additionalPreconditions(): Seq[StructType => Unit] = {
    Seq.empty
  }
}

// --- additions to object Preconditions in the same file ---

  /** Date/time column types accepted by isDateType. */
  private[this] val dateTypes =
    Set(TimestampType, DateType)

  /** Asserts that the specified column is of DateType or TimestampType;
    * throws WrongColumnTypeException otherwise.
    *
    * @param column column for which the assertion is performed
    */
  def isDateType(column: String): StructType => Unit = { schema =>
    val columnDataType = structField(column, schema).dataType
    if (!dateTypes.contains(columnDataType)) {
      throw new WrongColumnTypeException(s"Expected type of column $column to be one of " +
        s"(${dateTypes.mkString(",")}), but found $columnDataType instead!")
    }
  }

  /** Asserts that the specified column is of DecimalType;
    * throws WrongColumnTypeException otherwise.
    *
    * @param column column for which the assertion is performed
    */
  def isDecimalType(column: String): StructType => Unit = { schema =>
    val columnDataType = structField(column, schema).dataType
    val hasDecimalType = columnDataType match {
      case _: DecimalType => true
      case _ => false
    }
    if (!hasDecimalType) {
      // BUGFIX: the original message listed every numeric type although only
      // DecimalType satisfies this precondition — that would mislead users
      // whose Double/Long columns are rejected here.
      throw new WrongColumnTypeException(s"Expected type of column $column to be " +
        s"DecimalType, but found $columnDataType instead!")
    }
  }
// --- src/main/scala/com/amazon/deequ/analyzers/DateTimeDistribution.scala (new file, reconstructed) ---
// (Apache-2.0 license header as in the other files of this package.)

package com.amazon.deequ.analyzers

import java.sql.Timestamp

import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isDateType}
import com.amazon.deequ.analyzers.runners.MetricCalculationException
import com.amazon.deequ.metrics.{Distribution, DistributionValue, HistogramMetric}
import org.apache.spark.sql.DeequFunctions.dateTimeDistribution
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Column, Row}

import scala.util.{Failure, Success}

/** Supported bucket widths for the date-time distribution.
  * NOTE(review): MONTHLY is mapped to a fixed 7-day width by
  * getDateTimeAggIntervalValue below (months have no fixed length) — confirm intent. */
object DistributionInterval extends Enumeration {
  val QUARTER_HOUR, HOURLY, DAILY, WEEKLY, MONTHLY = Value
}

/** Sufficient statistics: row count per [bucketStart, bucketEnd] interval. */
case class DateTimeDistributionState(distribution: Map[(Timestamp, Timestamp), Long])
  extends State[DateTimeDistributionState] {

  /** Merges two partial histograms by adding the counts of identical buckets. */
  override def sum(other: DateTimeDistributionState): DateTimeDistributionState = {
    DateTimeDistributionState(distribution ++ other.distribution.map {
      case (k, v) => k -> (v + distribution.getOrElse(k, 0L))
    })
  }
}

object DateTimeDistributionState {

  /** Converts the raw (bucketStartEpochMillis -> count) aggregation result into
    * closed buckets of width `frequency` milliseconds. */
  def computeStateFromResult(
      result: Map[Long, Long],
      frequency: Long)
    : Map[(Timestamp, Timestamp), Long] = {
    result.map {
      case (start, count) =>
        (new Timestamp(start), new Timestamp(start + frequency - 1L)) -> count
    }
  }

  /** Renders the state as a Distribution with absolute and relative counts. */
  def toDistribution(histogram: DateTimeDistributionState): Distribution = {
    val totalCount = histogram.distribution.values.sum
    Distribution(
      histogram.distribution.map {
        case ((from, until), count) =>
          ("(" + from.toString + " to " + until.toString + ")") ->
            DistributionValue(count, count.toDouble / totalCount)
      },
      totalCount)
  }
}

/**
  * Histogram of a date/time column, bucketed into fixed-width intervals.
  *
  * @param column   column on which the distribution analysis is performed
  * @param interval bucket width of the distribution
  * @param where    optional filter condition
  */
case class DateTimeDistribution(
    column: String,
    interval: DistributionInterval.Value,
    where: Option[String] = None)
  extends ScanShareableAnalyzer[DateTimeDistributionState, HistogramMetric]
  with FilterableAnalyzer {

  /** Defines the aggregations to compute on the data */
  override private[deequ] def aggregationFunctions(): Seq[Column] = {
    dateTimeDistribution(conditionalSelection(column, where),
      DateTimeDistribution.getDateTimeAggIntervalValue(interval)) :: Nil
  }

  /** Computes the state from the result of the aggregation functions */
  override private[deequ] def fromAggregationResult(
      result: Row,
      offset: Int)
    : Option[DateTimeDistributionState] = {
    ifNoNullsIn(result, offset) { _ =>
      // BUGFIX: read the aggregation result at `offset`, not hard-coded 0 —
      // this analyzer may share a scan with others, so its result column is
      // not necessarily the first one in the row.
      DateTimeDistributionState(
        DateTimeDistributionState.computeStateFromResult(
          Map.empty[Long, Long] ++ result.getMap[Long, Long](offset),
          DateTimeDistribution.getDateTimeAggIntervalValue(interval)))
    }
  }

  override def preconditions: Seq[StructType => Unit] = {
    hasColumn(column) +: isDateType(column) +: super.preconditions
  }

  override def filterCondition: Option[String] = where

  /**
    * Compute the metric from the state (sufficient statistics)
    *
    * @param state wrapper holding a state of type S (required due to typing issues...)
    * @return the histogram metric, or a failure metric if no state is available
    */
  override def computeMetricFrom(state: Option[DateTimeDistributionState]): HistogramMetric = {
    state match {
      case Some(histogram) =>
        HistogramMetric(column, Success(DateTimeDistributionState.toDistribution(histogram)))
      case _ =>
        toFailureMetric(emptyStateException(this))
    }
  }

  override private[deequ] def toFailureMetric(failure: Exception): HistogramMetric = {
    HistogramMetric(column, Failure(MetricCalculationException.wrapIfNecessary(failure)))
  }
}

object DateTimeDistribution {

  /** Bucket width in milliseconds for the given interval.
    * NOTE(review): WEEKLY and MONTHLY both fall into the default 7-day case —
    * a true monthly bucketing cannot be expressed as a fixed width. */
  def getDateTimeAggIntervalValue(interval: DistributionInterval.Value): Long = {
    interval match {
      case DistributionInterval.QUARTER_HOUR => 900000L // 15 minutes
      case DistributionInterval.HOURLY => 3600000L      // 60 minutes
      case DistributionInterval.DAILY => 86400000L      // 24 hours
      case _ => 604800000L                              // 7 days (WEEKLY, MONTHLY)
    }
  }
}

// --- src/main/scala/com/amazon/deequ/analyzers/Maximum.scala (additions, reconstructed) ---
// New imports: java.math.BigDecimal; Preconditions.{hasColumn, isDecimalType, isNumeric}.
// (The patch also added an unused `min` import to this file — dropped.)

/** State holding the running maximum of a DecimalType column.
  * BUGFIX: the field was named `minValue` in the patch although it carries the maximum. */
case class MaxBigDecimalState(maxValue: BigDecimal)
  extends BigDecimalValuedState[MaxBigDecimalState] {

  override def sum(other: MaxBigDecimalState): MaxBigDecimalState = {
    MaxBigDecimalState(maxValue.max(other.maxValue))
  }

  override def metricValue(): BigDecimal = maxValue
}

/** Maximum of a DecimalType column, computed without precision loss. */
case class MaximumBigDecimal(column: String, where: Option[String] = None)
  extends BigDecimalScanShareableAnalyzer[MaxBigDecimalState]("Maximum BigDecimal", column)
  with FilterableAnalyzer {

  override def aggregationFunctions(): Seq[Column] = {
    max(conditionalSelection(column, where)) :: Nil
  }

  override def fromAggregationResult(result: Row, offset: Int): Option[MaxBigDecimalState] = {
    ifNoNullsIn(result, offset) { _ =>
      MaxBigDecimalState(result.getDecimal(offset))
    }
  }

  override protected def additionalPreconditions(): Seq[StructType => Unit] = {
    hasColumn(column) :: isDecimalType(column) :: Nil
  }

  override def filterCondition: Option[String] = where
}
// --- src/main/scala/com/amazon/deequ/analyzers/MaximumDateTime.scala (new file, reconstructed) ---
// (Apache-2.0 license header as in the other files of this package.)

package com.amazon.deequ.analyzers

import java.sql.Timestamp

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isDateType}
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.types.{StructType, TimestampType}
import org.apache.spark.sql.{Column, Row}
import Analyzers._

/** State holding the running maximum timestamp of a column. */
case class MaxDateTimeState(maxValue: Timestamp) extends DateTimeValuedState[MaxDateTimeState] {

  /** Keeps the later of the two timestamps. */
  override def sum(other: MaxDateTimeState): MaxDateTimeState = {
    MaxDateTimeState(if (maxValue.compareTo(other.maxValue) > 0) maxValue else other.maxValue)
  }

  override def metricValue(): Timestamp = maxValue
}

/** Maximum value of a DateType/TimestampType column. */
case class MaximumDateTime(column: String, where: Option[String] = None)
  extends TimestampScanShareableAnalyzer[MaxDateTimeState]("Maximum Date Time", column)
  with FilterableAnalyzer {

  override def aggregationFunctions(): Seq[Column] = {
    // the cast covers DateType columns, whose max would otherwise not be a Timestamp
    max(conditionalSelection(column, where)).cast(TimestampType) :: Nil
  }

  override def fromAggregationResult(result: Row, offset: Int): Option[MaxDateTimeState] = {
    ifNoNullsIn(result, offset) { _ =>
      MaxDateTimeState(result.getTimestamp(offset))
    }
  }

  override protected def additionalPreconditions(): Seq[StructType => Unit] = {
    hasColumn(column) :: isDateType(column) :: Nil
  }

  override def filterCondition: Option[String] = where
}

// --- src/main/scala/com/amazon/deequ/analyzers/Mean.scala (additions, reconstructed) ---
// New imports: java.math.{BigDecimal, MathContext}; Preconditions.isDecimalType.

/** Sufficient statistics (sum and row count) for the mean of a DecimalType column. */
case class BigDecimalMeanState(sum: BigDecimal, count: Long)
  extends BigDecimalValuedState[BigDecimalMeanState] {

  override def sum(other: BigDecimalMeanState): BigDecimalMeanState = {
    BigDecimalMeanState(sum.add(other.sum), count + other.count)
  }

  override def metricValue(): BigDecimal = {
    // BUGFIX: plain divide() throws ArithmeticException for non-terminating
    // quotients (e.g. 1/3); DECIMAL128 bounds the precision instead.
    // NOTE(review): the count == 0 branch returns null as in the patch; it is
    // unreachable via fromAggregationResult (ifNoNullsIn yields None there).
    if (count == 0L) null else sum.divide(new BigDecimal(count), MathContext.DECIMAL128)
  }
}

/** Mean of a DecimalType column, computed without converting to Double. */
case class BigDecimalMean(column: String, where: Option[String] = None)
  extends BigDecimalScanShareableAnalyzer[BigDecimalMeanState]("BigDecimal Mean", column)
  with FilterableAnalyzer {

  override def aggregationFunctions(): Seq[Column] = {
    val selection = conditionalSelection(column, where)
    sum(selection) :: count(selection).cast(LongType) :: Nil
  }

  override def fromAggregationResult(result: Row, offset: Int): Option[BigDecimalMeanState] = {
    ifNoNullsIn(result, offset, howMany = 2) { _ =>
      BigDecimalMeanState(result.getDecimal(offset), result.getLong(offset + 1))
    }
  }

  override protected def additionalPreconditions(): Seq[StructType => Unit] = {
    hasColumn(column) :: isDecimalType(column) :: Nil
  }

  override def filterCondition: Option[String] = where
}

// --- src/main/scala/com/amazon/deequ/analyzers/Minimum.scala (additions, reconstructed) ---
// New imports: java.math.BigDecimal; Preconditions.isDecimalType.

/** State holding the running minimum of a DecimalType column. */
case class MinBigDecimalState(minValue: BigDecimal)
  extends BigDecimalValuedState[MinBigDecimalState] {

  override def sum(other: MinBigDecimalState): MinBigDecimalState = {
    MinBigDecimalState(minValue.min(other.minValue))
  }

  override def metricValue(): BigDecimal = minValue
}

/** Minimum of a DecimalType column, computed without precision loss. */
case class MinimumBigDecimal(column: String, where: Option[String] = None)
  extends BigDecimalScanShareableAnalyzer[MinBigDecimalState]("Minimum BigDecimal", column)
  with FilterableAnalyzer {

  override def aggregationFunctions(): Seq[Column] = {
    min(conditionalSelection(column, where)) :: Nil
  }

  override def fromAggregationResult(result: Row, offset: Int): Option[MinBigDecimalState] = {
    ifNoNullsIn(result, offset) { _ =>
      MinBigDecimalState(result.getDecimal(offset))
    }
  }

  override protected def additionalPreconditions(): Seq[StructType => Unit] = {
    hasColumn(column) :: isDecimalType(column) :: Nil
  }

  override def filterCondition: Option[String] = where
}

// --- src/main/scala/com/amazon/deequ/analyzers/MinimumDateTime.scala (new file, reconstructed) ---
// (Apache-2.0 license header as in the other files of this package.)

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isDateType}
import org.apache.spark.sql.functions.min
import org.apache.spark.sql.types.{StructType, TimestampType}
import org.apache.spark.sql.{Column, Row}
import Analyzers._
import java.sql.Timestamp

/** A state which produces a Timestamp-valued metric.
  * NOTE(review): this trait is referenced from Analyzer.scala and
  * MaximumDateTime.scala — Analyzer.scala (next to DoubleValuedState and
  * BigDecimalValuedState) would be a more discoverable home. */
trait DateTimeValuedState[S <: DateTimeValuedState[S]] extends State[S] {
  def metricValue(): Timestamp
}

/** State holding the running minimum timestamp of a column. */
case class MinDateTimeState(minValue: Timestamp) extends DateTimeValuedState[MinDateTimeState] {

  /** Keeps the earlier of the two timestamps. */
  override def sum(other: MinDateTimeState): MinDateTimeState = {
    MinDateTimeState(if (minValue.compareTo(other.minValue) < 0) minValue else other.minValue)
  }

  override def metricValue(): Timestamp = minValue
}

/** Minimum value of a DateType/TimestampType column. */
case class MinimumDateTime(column: String, where: Option[String] = None)
  extends TimestampScanShareableAnalyzer[MinDateTimeState]("Minimum Date Time", column)
  with FilterableAnalyzer {

  override def aggregationFunctions(): Seq[Column] = {
    min(conditionalSelection(column, where)).cast(TimestampType) :: Nil
  }

  override def fromAggregationResult(result: Row, offset: Int): Option[MinDateTimeState] = {
    ifNoNullsIn(result, offset) { _ =>
      MinDateTimeState(result.getTimestamp(offset))
    }
  }

  override protected def additionalPreconditions(): Seq[StructType => Unit] = {
    hasColumn(column) :: isDateType(column) :: Nil
  }

  override def filterCondition: Option[String] = where
}
// --- src/main/scala/com/amazon/deequ/analyzers/Sum.scala (additions, reconstructed) ---
// New imports: java.math.BigDecimal; Preconditions.isDecimalType.

/** State holding the running sum of a DecimalType column. */
case class BigDecimalSumState(sum: BigDecimal) extends BigDecimalValuedState[BigDecimalSumState] {

  override def sum(other: BigDecimalSumState): BigDecimalSumState = {
    BigDecimalSumState(sum.add(other.sum))
  }

  override def metricValue(): BigDecimal = sum
}

/** Sum of a DecimalType column, computed without precision loss. */
case class BigDecimalSum(column: String, where: Option[String] = None)
  extends BigDecimalScanShareableAnalyzer[BigDecimalSumState]("BigDecimal Sum", column)
  with FilterableAnalyzer {

  override def aggregationFunctions(): Seq[Column] = {
    sum(conditionalSelection(column, where)) :: Nil
  }

  override def fromAggregationResult(result: Row, offset: Int): Option[BigDecimalSumState] = {
    ifNoNullsIn(result, offset) { _ =>
      BigDecimalSumState(result.getDecimal(offset))
    }
  }

  override protected def additionalPreconditions(): Seq[StructType => Unit] = {
    hasColumn(column) :: isDecimalType(column) :: Nil
  }

  override def filterCondition: Option[String] = where
}

// --- src/main/scala/com/amazon/deequ/analyzers/catalyst/DateTimeAggregation.scala (new file, reconstructed) ---
// (Apache-2.0 license header as in the other files of this package.)

package com.amazon.deequ.analyzers.catalyst

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/** UDAF building a histogram of timestamps bucketed into fixed windows of
  * `frequency` milliseconds: the result maps bucket start (epoch millis) to row count.
  * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
  * in favour of Aggregator — consider migrating if the build targets Spark 3+. */
class DateTimeAggregation(frequency: Long) extends UserDefinedAggregateFunction {

  override def inputSchema: StructType = StructType(StructField("value", TimestampType) :: Nil)

  override def bufferSchema: StructType = StructType(StructField("map",
    DataTypes.createMapType(LongType, LongType)) :: Nil)

  override def dataType: DataType = DataTypes.createMapType(LongType, LongType)

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Map.empty[Long, Long])
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val epochMillis = input.getTimestamp(0).getTime
      // Align the timestamp to the start of its bucket.
      val bucketStart = epochMillis - (epochMillis % frequency)
      // getMap instead of the original buffer(0).asInstanceOf[Map[Long, Long]]:
      // once Spark round-trips the buffer it may hand back its own Map
      // implementation, which the unchecked cast could reject at runtime.
      val counts = buffer.getMap[Long, Long](0)
      buffer(0) = counts.toMap + (bucketStart -> (counts.getOrElse(bucketStart, 0L) + 1L))
    }
  }

  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val left = buffer1.getMap[Long, Long](0).toMap
    val right = buffer2.getMap[Long, Long](0)
    buffer1(0) = left ++ right.map { case (k, v) => k -> (v + left.getOrElse(k, 0L)) }
  }

  override def evaluate(buffer: Row): Any = buffer.getMap(0)
}

// --- src/main/scala/com/amazon/deequ/analyzers/catalyst/DeequFunctions.scala (changed section, reconstructed) ---
// BUGFIX: the patch REPLACED `import com.amazon.deequ.analyzers.KLLSketch` with the
// DateTimeAggregation import, but stateful_kll below this hunk still appears to
// reference KLLSketch — keep BOTH imports:
//   import com.amazon.deequ.analyzers.KLLSketch
//   import com.amazon.deequ.analyzers.catalyst.DateTimeAggregation

  /** Returns a date-time distribution aggregation column.
    *
    * @param column   column on which the aggregation is performed
    * @param interval bucket width of the aggregation, in milliseconds
    * @return aggregation expression as a Column
    */
  def dateTimeDistribution(column: Column, interval: Long): Column = {
    val dateTimeAgg = new DateTimeAggregation(interval)
    dateTimeAgg(column)
  }
+import com.amazon.deequ.constraints.Constraint.{maxTimestampConstraint, minTimestampConstraint, _} import com.amazon.deequ.constraints._ import com.amazon.deequ.metrics.{BucketDistribution, Distribution, Metric} import com.amazon.deequ.repository.MetricsRepository import org.apache.spark.sql.expressions.UserDefinedFunction import com.amazon.deequ.anomalydetection.HistoryUtils -import com.amazon.deequ.checks.ColumnCondition.{isEachNotNull, isAnyNotNull} +import com.amazon.deequ.checks.ColumnCondition.{isAnyNotNull, isEachNotNull} import scala.util.matching.Regex @@ -924,6 +926,22 @@ case class Check( isContainedIn(column, allowedValues, Check.IsOne, None) } + /** + * Asserts that every non-null value in a column is contained in a set of predefined values + *for providing isContained availble for numeric types as well + * + * @param column Column to run the assertion on + * @param allowedValues allowed values for the column + * @return + */ + def isContainedIn[T <: AnyVal]( + column: String, + allowedValues: Array[T]) + : CheckWithLastConstraintFilterable = { + + isContainedIn(column, allowedValues, Check.IsOne, None) + } + // We can't use default values here as you can't combine default values and overloading in Scala /** * Asserts that every non-null value in a column is contained in a set of predefined values @@ -942,6 +960,15 @@ case class Check( isContainedIn(column, allowedValues, Check.IsOne, hint) } + def isContainedIn[T <: AnyVal]( + column: String, + allowedValues: Array[T], + hint: Option[String]) + : CheckWithLastConstraintFilterable = { + + isContainedIn(column, allowedValues, Check.IsOne, hint) + } + // We can't use default values here as you can't combine default values and overloading in Scala /** * Asserts that every non-null value in a column is contained in a set of predefined values @@ -971,22 +998,32 @@ case class Check( * @param hint A hint to provide additional context why a constraint could have failed * @return */ - def isContainedIn( + def 
isContainedIn[T]( column: String, - allowedValues: Array[String], + allowedValues: Array[T], assertion: Double => Boolean, hint: Option[String]) : CheckWithLastConstraintFilterable = { - - val valueList = allowedValues - .map { _.replaceAll("'", "''") } - .mkString("'", "','", "'") + val valueList = getValueList(allowedValues) val predicate = s"`$column` IS NULL OR `$column` IN ($valueList)" satisfies(predicate, s"$column contained in ${allowedValues.mkString(",")}", assertion, hint) } + + def getValueList[T](allowedValues: Array[_]): String = { + allowedValues match { + case allowedValues : Array[String] => allowedValues + .map { + _.replaceAll("'", "''") + } + .mkString("'", "','", "'") + case allowedValues : Array[Char] => allowedValues.mkString("'", "','", "'") + case _ => allowedValues.mkString(",") + } + } + /** * Asserts that the non-null values in a numeric column fall into the predefined interval * @@ -1013,7 +1050,111 @@ case class Check( val predicate = s"`$column` IS NULL OR " + s"(`$column` $leftOperand $lowerBound AND `$column` $rightOperand $upperBound)" - satisfies(predicate, s"$column between $lowerBound and $upperBound", hint = hint) + satisfies(predicate, s"`$column` between $lowerBound and $upperBound", hint = hint) + } + + def hasMinTimestamp( + column: String, + assertion: Timestamp => Boolean, + hint: Option[String] = None) + : CheckWithLastConstraintFilterable = { + + addFilterableConstraint { filter => minTimestampConstraint(column, assertion, filter, hint) } + } + + def hasMaxTimestamp( + column: String, + assertion: Timestamp => Boolean, + hint: Option[String] = None) + : CheckWithLastConstraintFilterable = { + + addFilterableConstraint { filter => maxTimestampConstraint(column, assertion, filter, hint) } + } + + /** + * Asserts that, in each row, the value of column (DateType or TimestampType) + * is less than the given datetime (Timestamp) + * + * @param column Column to run the assertion on + * @param datetime value of Timestamp to run 
assert + * @param assertion Function that receives a Timestamp input parameter and returns a boolean + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def isDateTimeLessThan( + column: String, + datetime: Timestamp, + assertion: Double => Boolean = Check.IsOne, + hint: Option[String] = None) + : CheckWithLastConstraintFilterable = { + + satisfies(s"$column < to_timestamp('${datetime.toString}')", + s"$column is less than '${datetime.toString}'", assertion, + hint = hint) + } + + /** + * + * Asserts that, in each row, the value of column (DateType or TimestampType) + * is greater than the given datetime (Timestamp) + * + * @param column Column to run the assertion on + * @param datetime value of Timestamp to run assert + * @param assertion Function that receives a Timestamp input parameter and returns a boolean + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def isDateTimeGreaterThan( + column: String, + datetime: Timestamp, + assertion: Double => Boolean = Check.IsOne, + hint: Option[String] = None) + : CheckWithLastConstraintFilterable = { + + satisfies(s"$column > to_timestamp('${datetime.toString}')", + s"$column is greater than '${datetime.toString}'", assertion, + hint = hint) + } + + /** + * + * Asserts that, in each row, the value of column (DateType or TimestampType) contains a past date + * + * @param column Column to run the assertion on + * @param assertion Function that receives a Timestamp input parameter and returns a boolean + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def hasPastDates( + column: String, + assertion: Double => Boolean = Check.IsOne, + hint: Option[String] = None) + : CheckWithLastConstraintFilterable = { + + satisfies(s"$column < now()", + s"$column has all past dates", assertion, + hint = hint) + } + + /** + * + * Asserts that, in each row, the value of 
column (DateType or TimestampType) + * contains a future date + * + * @param column Column to run the assertion on + * @param assertion Function that receives a Double input parameter (the fraction of rows satisfying the predicate) and returns a boolean + * @param hint A hint to provide additional context why a constraint could have failed + * @return + */ + def hasFutureDates( + column: String, + assertion: Double => Boolean = Check.IsOne, + hint: Option[String] = None) + : CheckWithLastConstraintFilterable = { + + satisfies(s"$column > now()", + s"$column has all future dates", assertion, + hint = hint) } /** @@ -1127,4 +1268,5 @@ object Check { detectedAnomalies.anomalies.isEmpty } + } diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala index c7963ce41..0ef06dfc1 100644 --- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala +++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala @@ -16,7 +16,9 @@ package com.amazon.deequ.constraints -import com.amazon.deequ.analyzers._ +import java.sql.Timestamp + +import com.amazon.deequ.analyzers._ import com.amazon.deequ.metrics.{BucketDistribution, Distribution, Metric} import org.apache.spark.sql.expressions.UserDefinedFunction @@ -484,6 +486,35 @@ object Constraint { + def minTimestampConstraint( + column: String, + assertion: Timestamp => Boolean, + where: Option[String] = None, + hint: Option[String] = None) + : Constraint = { + + val minimum = MinimumDateTime(column, where) + + val constraint = AnalysisBasedConstraint[MinDateTimeState, Timestamp, + Timestamp](minimum, assertion, hint = hint) + + new NamedConstraint(constraint, s"MinimumTimestampConstraint($minimum)") + } + + def maxTimestampConstraint( + column: String, + assertion: Timestamp => Boolean, + where: Option[String] = None, + hint: Option[String] = None) + : Constraint = { + + val maximum = 
MaximumDateTime(column, where) + + val constraint = AnalysisBasedConstraint[MaxDateTimeState, Timestamp, + Timestamp](maximum, assertion, hint = hint) + + new NamedConstraint(constraint, s"MaximumTimestampConstraint($maximum)") + } /** * Runs mean analysis on the given column and executes the assertion * diff --git a/src/main/scala/com/amazon/deequ/examples/AnalyzerExample.scala b/src/main/scala/com/amazon/deequ/examples/AnalyzerExample.scala new file mode 100644 index 000000000..e3813d5c1 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/examples/AnalyzerExample.scala @@ -0,0 +1,58 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
 + * + */ + +package com.amazon.deequ.examples + +import java.math.BigDecimal +import java.sql.Timestamp + + +import com.amazon.deequ.analyzers.runners.{AnalysisRunner, AnalyzerContext} +import com.amazon.deequ.analyzers.{BigDecimalMean, BigDecimalSum, DateTimeDistribution, DistributionInterval, MaximumDateTime, Mean, MinimumDateTime, Sum} +import com.amazon.deequ.examples.ExampleUtils.{ordersAsDataframe, withSpark} +import com.amazon.deequ.analyzers.runners.AnalyzerContext.successMetricsAsDataFrame + +private[examples] object AnalyzerExample extends App { + withSpark { session => + + val data = ordersAsDataframe(session, + Order(1, new BigDecimal("213.2132"), Timestamp.valueOf("2020-02-15 07:15:00")), + Order(2, new BigDecimal("43.21324432876"), Timestamp.valueOf("2020-02-15 07:45:00.999")), + Order(3, new BigDecimal("56.8881238823888"), Timestamp.valueOf("2020-02-15 08:15:49.786")), + Order(4, new BigDecimal("101.2324434978788"), Timestamp.valueOf("2020-02-15 12:15:00")), + Order(5, new BigDecimal("723.234324234324324324"), Timestamp.valueOf("2020-02-15 15:14:23.678")) + ) + + val analysisResult: AnalyzerContext = { AnalysisRunner + .onData(data) + .addAnalyzer(DateTimeDistribution("orderDate", DistributionInterval.HOURLY)) + .addAnalyzer(MinimumDateTime("orderDate")) + .addAnalyzer(MaximumDateTime("orderDate")) + .addAnalyzer(Sum("amount")) + .addAnalyzer(BigDecimalSum("amount")) + .addAnalyzer(Mean("amount")) + .addAnalyzer(BigDecimalMean("amount")) + .run() + } + + successMetricsAsDataFrame(session, analysisResult).show(false) + + analysisResult.metricMap.foreach( x => + println(s"column '${x._2.instance}' has ${x._2.name} : ${x._2.value.get}") + ) + + } +} diff --git a/src/main/scala/com/amazon/deequ/examples/ExampleUtils.scala b/src/main/scala/com/amazon/deequ/examples/ExampleUtils.scala index 699711a5d..44f305be7 100644 --- a/src/main/scala/com/amazon/deequ/examples/ExampleUtils.scala +++ b/src/main/scala/com/amazon/deequ/examples/ExampleUtils.scala @@ -45,4 +45,9 @@ 
private[deequ] object ExampleUtils { val rdd = session.sparkContext.parallelize(manufacturers) session.createDataFrame(rdd) } + + def ordersAsDataframe(session: SparkSession, orders: Order*): DataFrame = { + val rdd = session.sparkContext.parallelize(orders) + session.createDataFrame(rdd) + } } diff --git a/src/main/scala/com/amazon/deequ/examples/entities.scala b/src/main/scala/com/amazon/deequ/examples/entities.scala index f2750ecfe..ebd81eb64 100644 --- a/src/main/scala/com/amazon/deequ/examples/entities.scala +++ b/src/main/scala/com/amazon/deequ/examples/entities.scala @@ -16,6 +16,9 @@ package com.amazon.deequ.examples +import java.sql.Timestamp +import java.math.BigDecimal + private[deequ] case class Item( id: Long, productName: String, @@ -24,6 +27,12 @@ private[deequ] case class Item( numViews: Long ) +private[deequ] case class Order( + id: Long, + amount: BigDecimal, + orderDate: Timestamp +) + private[deequ] case class Manufacturer( id: Long, manufacturerName: String, diff --git a/src/main/scala/com/amazon/deequ/metrics/Metric.scala b/src/main/scala/com/amazon/deequ/metrics/Metric.scala index 0964d160a..d5036d535 100644 --- a/src/main/scala/com/amazon/deequ/metrics/Metric.scala +++ b/src/main/scala/com/amazon/deequ/metrics/Metric.scala @@ -16,6 +16,8 @@ package com.amazon.deequ.metrics +import java.sql.Timestamp + import scala.util.{Failure, Success, Try} object Entity extends Enumeration { @@ -66,3 +68,36 @@ case class KeyedDoubleMetric( } } } + +case class DateTimeMetric( + entity: Entity.Value, + name: String, + instance: String, + value: Try[Timestamp]) + extends Metric[Timestamp] { + + override def flatten(): Seq[DoubleMetric] = { + if (value.isSuccess) { + Seq(DoubleMetric(entity, "Timestamp milliseconds", instance, + Success(value.get.getTime.toDouble))) + } else { + Seq(DoubleMetric(entity, "Timestamp milliseconds", instance, Failure(value.failed.get))) + } + } +} + +case class BigDecimalMetric( + entity: Entity.Value, + name: String, + instance: 
String, + value: Try[BigDecimal]) + extends Metric[BigDecimal] { + + override def flatten(): Seq[DoubleMetric] = { + if (value.isSuccess) { + Seq(DoubleMetric(entity, "BigDecimal", instance, Success(value.get.toDouble))) + } else { + Seq(DoubleMetric(entity, "BigDecimal", instance, Failure(value.failed.get))) + } + } +}