Commit 07e4b4f

#314: Add support for sum of truncated values (#324)
- Add support for sum of truncated values.
- Added the aggregatedTruncTotal and absAggregatedTruncTotal measures.

Closes #314

Release notes:
- Added the aggregatedTruncTotal and absAggregatedTruncTotal measures.
- Added tests for these measures.
1 parent 14e90a5 commit 07e4b4f
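
The two new measures sum column values after truncating them toward zero and report a Long result. A minimal usage sketch, assuming an existing SparkSession named spark; the column name and sample data are illustrative, and the expected results are inferred by analogy with the unit tests in this commit:

import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfTruncatedValuesOfColumn, SumOfTruncatedValuesOfColumn}

// Hypothetical data: a string "salary" column holding decimal values.
val df = spark
  .createDataFrame(Seq(("a", "3000.98"), ("b", "-1000.76")))
  .toDF("id", "salary")

val truncSum    = SumOfTruncatedValuesOfColumn("salary")     // measure name "aggregatedTruncTotal"
val absTruncSum = AbsSumOfTruncatedValuesOfColumn("salary")  // measure name "absAggregatedTruncTotal"

// Values are truncated toward zero before summing: 3000.98 -> 3000, -1000.76 -> -1000.
val truncResult    = truncSum.function(df)     // expected resultValue "2000", resultValueType LongValue
val absTruncResult = absTruncSum.function(df)  // expected resultValue "4000", resultValueType LongValue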

File tree

5 files changed: +149 -25 lines

- agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
- agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala
- agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala
- agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala
- agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala

agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala

Lines changed: 36 additions & 8 deletions
@@ -39,14 +39,6 @@ final case class UnknownMeasure(measureName: String, measuredColumns: Seq[String
 
 object AtumMeasure {
 
-  val supportedMeasureNames: Seq[String] = Seq(
-    RecordCount.measureName,
-    DistinctRecordCount.measureName,
-    SumOfValuesOfColumn.measureName,
-    AbsSumOfValuesOfColumn.measureName,
-    SumOfHashesOfColumn.measureName
-  )
-
   case class RecordCount private (measureName: String) extends AtumMeasure {
     private val columnExpression = count("*")
 
@@ -117,6 +109,42 @@ object AtumMeasure {
     def apply(measuredCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, measuredCol)
   }
 
+  case class SumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
+
+    private val columnAggFn: Column => Column = column => sum(when(column >= 0, floor(column)).otherwise(ceil(column)))
+
+    override def function: MeasurementFunction = (ds: DataFrame) => {
+      val dataType = ds.select(measuredCol).schema.fields(0).dataType
+      val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
+      MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
+    }
+
+    override def measuredColumns: Seq[String] = Seq(measuredCol)
+    override val resultValueType: ResultValueType = ResultValueType.LongValue
+  }
+  object SumOfTruncatedValuesOfColumn {
+    private[agent] val measureName: String = "aggregatedTruncTotal"
+    def apply(measuredCol: String): SumOfTruncatedValuesOfColumn = SumOfTruncatedValuesOfColumn(measureName, measuredCol)
+  }
+
+  case class AbsSumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
+
+    private val columnAggFn: Column => Column = column => sum(abs(when(column >= 0, floor(column)).otherwise(ceil(column))))
+
+    override def function: MeasurementFunction = (ds: DataFrame) => {
+      val dataType = ds.select(measuredCol).schema.fields(0).dataType
+      val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
+      MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
+    }
+
+    override def measuredColumns: Seq[String] = Seq(measuredCol)
+    override val resultValueType: ResultValueType = ResultValueType.LongValue
+  }
+  object AbsSumOfTruncatedValuesOfColumn {
+    private[agent] val measureName: String = "absAggregatedTruncTotal"
+    def apply(measuredCol: String): AbsSumOfTruncatedValuesOfColumn = AbsSumOfTruncatedValuesOfColumn(measureName, measuredCol)
+  }
+
   case class SumOfHashesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
     private val columnExpression: Column = sum(crc32(col(measuredCol).cast("String")))
     override def function: MeasurementFunction = (ds: DataFrame) => {
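
The when(column >= 0, floor(column)).otherwise(ceil(column)) expression above is what truncates each value toward zero before summing (and, in the abs variant, before taking the absolute value). A standalone Spark sketch of the same expression outside the agent, assuming an existing SparkSession named spark:

import org.apache.spark.sql.functions.{abs, ceil, col, floor, sum, when}
import spark.implicits._

// Truncation toward zero: floor for non-negative values, ceil for negative ones.
val truncToZero = when(col("v") >= 0, floor(col("v"))).otherwise(ceil(col("v")))

val df = Seq(1.98, -1.76, 1.54, 1.32).toDF("v")
df.select(sum(truncToZero), sum(abs(truncToZero))).show()
// 1 + (-1) + 1 + 1 = 2  and  1 + 1 + 1 + 1 = 4, matching the new unit tests below.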

agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@ private [agent] object MeasuresBuilder extends Logging {
       case DistinctRecordCount.measureName => DistinctRecordCount(measuredColumns)
       case SumOfValuesOfColumn.measureName => SumOfValuesOfColumn(measuredColumns.head)
       case AbsSumOfValuesOfColumn.measureName => AbsSumOfValuesOfColumn(measuredColumns.head)
+      case SumOfTruncatedValuesOfColumn.measureName => SumOfTruncatedValuesOfColumn(measuredColumns.head)
+      case AbsSumOfTruncatedValuesOfColumn.measureName => AbsSumOfTruncatedValuesOfColumn(measuredColumns.head)
       case SumOfHashesOfColumn.measureName => SumOfHashesOfColumn(measuredColumns.head)
     }
   }.toOption
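
With these two cases in place, the builder resolves the measure names aggregatedTruncTotal and absAggregatedTruncTotal to the new measure classes. A rough sketch of that name-to-measure resolution, written as a free-standing function for illustration only (the real logic lives in the private MeasuresBuilder shown above):

import za.co.absa.atum.agent.model.AtumMeasure
import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfTruncatedValuesOfColumn, SumOfTruncatedValuesOfColumn}

// Hypothetical helper mirroring the pattern match in MeasuresBuilder.
def resolveTruncMeasure(measureName: String, measuredColumns: Seq[String]): Option[AtumMeasure] =
  measureName match {
    case "aggregatedTruncTotal"    => Some(SumOfTruncatedValuesOfColumn(measuredColumns.head))
    case "absAggregatedTruncTotal" => Some(AbsSumOfTruncatedValuesOfColumn(measuredColumns.head))
    case _                         => None
  }

resolveTruncMeasure("aggregatedTruncTotal", Seq("salary"))  // Some(SumOfTruncatedValuesOfColumn(...))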

agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala

Lines changed: 65 additions & 6 deletions
@@ -32,12 +32,12 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
   "Measure" should "be based on the dataframe" in {
 
     // Measures
-    val measureIds: AtumMeasure = RecordCount()
-    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn(
-      measuredCol = "salary"
-    )
-    val salarySum = SumOfValuesOfColumn(measuredCol = "salary")
-    val sumOfHashes: AtumMeasure = SumOfHashesOfColumn(measuredCol = "id")
+    val measureIds: AtumMeasure = RecordCount()
+    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn(measuredCol = "salary")
+    val salarySum = SumOfValuesOfColumn(measuredCol = "salary")
+    val salaryTruncSum = SumOfTruncatedValuesOfColumn(measuredCol = "salary")
+    val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn(measuredCol = "salary")
+    val sumOfHashes: AtumMeasure = SumOfHashesOfColumn(measuredCol = "id")
 
     // AtumContext contains `Measurement`
     val atumContextInstanceWithRecordCount = AtumAgent
@@ -86,12 +86,34 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
         .removeMeasure(salaryAbsSum)
     )
 
+    val dfExtraPersonWithDecimalSalary = spark
+      .createDataFrame(
+        Seq(
+          ("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
+          ("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
+        )
+      )
+      .toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")
+
+    val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)
+
+    dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
+      atumContextWithSalaryAbsMeasure
+        .removeMeasure(measureIds)
+        .removeMeasure(salaryAbsSum)
+    )
+
+
     val dfPersonCntResult = measureIds.function(dfPersons)
     val dfFullCntResult = measureIds.function(dfFull)
     val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
     val dfFullHashResult = sumOfHashes.function(dfFull)
     val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
     val dfFullSalarySumResult = salarySum.function(dfFull)
+    val dfExtraPersonSalarySumTruncResult = salaryTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalarySumTruncResult = salaryTruncSum.function(dfFull)
+    val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfFull)
 
     // Assertions
     assert(dfPersonCntResult.resultValue == "1000")
@@ -106,6 +128,14 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
     assert(dfFullSalarySumResult.resultValue == "2987144")
     assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
+    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfFullSalarySumTruncResult.resultValue == "2987144")
+    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
+    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
   }
 
   "AbsSumOfValuesOfColumn" should "return expected value" in {
@@ -187,4 +217,33 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
     assert(result.resultValueType == ResultValueType.BigDecimalValue)
   }
 
+  "SumTruncOfValuesOfColumn" should "return expected value" in {
+    val distinctCount = SumOfTruncatedValuesOfColumn("colA")
+
+    val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("1.32", "b2"))
+    val rdd = spark.sparkContext.parallelize(data)
+
+    val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
+    val df = spark.createDataFrame(rdd, schema)
+
+    val result = distinctCount.function(df)
+
+    assert(result.resultValue == "2")
+    assert(result.resultValueType == ResultValueType.LongValue)
+  }
+
+  "AbsSumTruncOfValuesOfColumn" should "return expected value" in {
+    val distinctCount = AbsSumOfTruncatedValuesOfColumn("colA")
+
+    val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("-1.32", "b2"))
+    val rdd = spark.sparkContext.parallelize(data)
+
+    val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
+    val df = spark.createDataFrame(rdd, schema)
+
+    val result = distinctCount.function(df)
+
+    assert(result.resultValue == "4")
+    assert(result.resultValueType == ResultValueType.LongValue)
+  }
 }
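
The expected values in the new assertions are straightforward arithmetic: the base persons data sums to 2987144 (its salaries are whole numbers, so truncation changes nothing, as the dfFull assertions show), and the two extra decimal rows truncate toward zero to 3000 and -1000. A quick check of those numbers in plain Scala:

// Truncation toward zero of the two extra decimal salaries from the test data.
val truncatedExtras = Seq(BigDecimal("3000.98"), BigDecimal("-1000.76")).map(_.toLong)  // 3000, -1000

val base = 2987144L  // sum of the whole-number salaries, per the existing assertions
val expectedTruncSum    = base + truncatedExtras.sum             // 2987144 + 3000 - 1000 = 2989144
val expectedAbsTruncSum = base + truncatedExtras.map(_.abs).sum  // 2987144 + 3000 + 1000 = 2991144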

agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala

Lines changed: 42 additions & 11 deletions
@@ -19,7 +19,7 @@ package za.co.absa.atum.agent.model
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
 import za.co.absa.atum.agent.AtumAgent
-import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn}
+import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn, SumOfTruncatedValuesOfColumn, AbsSumOfTruncatedValuesOfColumn}
 import za.co.absa.spark.commons.test.SparkTestBase
 import za.co.absa.atum.agent.AtumContext._
 import za.co.absa.atum.model.ResultValueType
@@ -30,11 +30,13 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
   "Measure" should "be based on the dataframe" in {
 
     // Measures
-    val measureIds: AtumMeasure = RecordCount()
-    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
+    val measureIds: AtumMeasure = RecordCount()
+    val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
+    val sumOfHashes: AtumMeasure = SumOfHashesOfColumn("id")
 
-    val salarySum = SumOfValuesOfColumn("salary")
-    val sumOfHashes: AtumMeasure = SumOfHashesOfColumn("id")
+    val salarySum = SumOfValuesOfColumn("salary")
+    val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn("salary")
+    val salaryTruncSum = SumOfTruncatedValuesOfColumn("salary")
 
     // AtumContext contains `Measurement`
     val atumContextInstanceWithRecordCount = AtumAgent
@@ -83,12 +85,33 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
         .removeMeasure(salaryAbsSum)
     )
 
-    val dfPersonCntResult = measureIds.function(dfPersons)
-    val dfFullCntResult = measureIds.function(dfFull)
-    val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
-    val dfFullHashResult = sumOfHashes.function(dfFull)
-    val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
-    val dfFullSalarySumResult = salarySum.function(dfFull)
+    val dfExtraPersonWithDecimalSalary = spark
+      .createDataFrame(
+        Seq(
+          ("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
+          ("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
+        )
+      )
+      .toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")
+
+    val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)
+
+    dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
+      atumContextWithSalaryAbsMeasure
+        .removeMeasure(measureIds)
+        .removeMeasure(salaryAbsSum)
+    )
+
+    val dfPersonCntResult = measureIds.function(dfPersons)
+    val dfFullCntResult = measureIds.function(dfFull)
+    val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
+    val dfFullHashResult = sumOfHashes.function(dfFull)
+    val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
+    val dfFullSalarySumResult = salarySum.function(dfFull)
+    val dfExtraPersonSalarySumTruncResult = salaryTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalarySumTruncResult = salaryTruncSum.function(dfFull)
+    val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
+    val dfFullSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfFull)
 
     // Assertions
     assert(dfPersonCntResult.resultValue == "1000")
@@ -103,6 +126,14 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
     assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
     assert(dfFullSalarySumResult.resultValue == "2987144")
     assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
+    assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
+    assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfFullSalarySumTruncResult.resultValue == "2987144")
+    assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
+    assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
+    assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
+    assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.LongValue)
   }
 
 }

agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala

Lines changed: 4 additions & 0 deletions
@@ -28,6 +28,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike {
     MeasureDTO("distinctCount", Seq("distinctCountCol")),
     MeasureDTO("aggregatedTotal", Seq("aggregatedTotalCol")),
     MeasureDTO("absAggregatedTotal", Seq("absAggregatedTotalCol")),
+    MeasureDTO("aggregatedTruncTotal", Seq("aggregatedTruncTotalCol")),
+    MeasureDTO("absAggregatedTruncTotal", Seq("absAggregatedTruncTotalCol")),
     MeasureDTO("hashCrc32", Seq("hashCrc32Col"))
   )
 
@@ -36,6 +38,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike {
     DistinctRecordCount(Seq("distinctCountCol")),
     SumOfValuesOfColumn("aggregatedTotalCol"),
     AbsSumOfValuesOfColumn("absAggregatedTotalCol"),
+    SumOfTruncatedValuesOfColumn("aggregatedTruncTotalCol"),
+    AbsSumOfTruncatedValuesOfColumn("absAggregatedTruncTotalCol"),
     SumOfHashesOfColumn("hashCrc32Col")
   )
