Skip to content

Commit 007f45c

Browse files
committed
Changes made:
- Added the aggregatedTruncTotal and absAggregatedTruncTotal measures.
- Added tests for these measures.
1 parent 39e500a commit 007f45c

File tree

5 files changed

+147
-11
lines changed

5 files changed

+147
-11
lines changed

agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ object AtumMeasure {
4444
DistinctRecordCount.measureName,
4545
SumOfValuesOfColumn.measureName,
4646
AbsSumOfValuesOfColumn.measureName,
47+
SumOfTruncatedValuesOfColumn.measureName,
48+
AbsSumOfTruncatedValuesOfColumn.measureName,
4749
SumOfHashesOfColumn.measureName
4850
)
4951

@@ -117,6 +119,42 @@ object AtumMeasure {
117119
def apply(measuredCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, measuredCol)
118120
}
119121

122+
case class SumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
  // Truncate each value toward zero (the LongType cast drops the fractional part),
  // then cast back to DecimalType(38, 0) so the sum stays compatible with the
  // other decimal-based aggregation measures.
  private val truncatedSum: Column => Column =
    c => sum(c.cast(LongType).cast(DecimalType(38, 0)))

  /** Computes the sum of the truncated values of `measuredCol` over the DataFrame. */
  override def function: MeasurementFunction = (ds: DataFrame) => {
    val sourceType = ds.select(measuredCol).schema.fields(0).dataType
    val aggregated = ds
      .select(truncatedSum(castForAggregation(sourceType, col(measuredCol))))
      .collect()
    MeasureResult(handleAggregationResult(sourceType, aggregated(0)(0)), resultValueType)
  }

  override def measuredColumns: Seq[String] = Seq(measuredCol)

  override val resultValueType: ResultValueType = ResultValueType.BigDecimalValue
}
135+
object SumOfTruncatedValuesOfColumn {
  // Persisted identifier under which this measure is stored and looked up.
  private[agent] val measureName: String = "aggregatedTruncTotal"

  /** Builds the measure for the given column with the canonical measure name. */
  def apply(measuredCol: String): SumOfTruncatedValuesOfColumn =
    SumOfTruncatedValuesOfColumn(measureName, measuredCol)
}
139+
140+
case class AbsSumOfTruncatedValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
  // Truncate each value toward zero (LongType cast), re-cast to DecimalType(38, 0)
  // for compatibility with the other decimal measures, take the absolute value,
  // and sum the results.
  private val absTruncatedSum: Column => Column =
    c => sum(abs(c.cast(LongType).cast(DecimalType(38, 0))))

  /** Computes the sum of the absolute truncated values of `measuredCol` over the DataFrame. */
  override def function: MeasurementFunction = (ds: DataFrame) => {
    val sourceType = ds.select(measuredCol).schema.fields(0).dataType
    val aggregated = ds
      .select(absTruncatedSum(castForAggregation(sourceType, col(measuredCol))))
      .collect()
    MeasureResult(handleAggregationResult(sourceType, aggregated(0)(0)), resultValueType)
  }

  override def measuredColumns: Seq[String] = Seq(measuredCol)

  override val resultValueType: ResultValueType = ResultValueType.BigDecimalValue
}
153+
object AbsSumOfTruncatedValuesOfColumn {
  // Persisted identifier under which this measure is stored and looked up.
  private[agent] val measureName: String = "absAggregatedTruncTotal"

  /** Builds the measure for the given column with the canonical measure name. */
  def apply(measuredCol: String): AbsSumOfTruncatedValuesOfColumn =
    AbsSumOfTruncatedValuesOfColumn(measureName, measuredCol)
}
157+
120158
case class SumOfHashesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
121159
private val columnExpression: Column = sum(crc32(col(measuredCol).cast("String")))
122160
override def function: MeasurementFunction = (ds: DataFrame) => {

agent/src/main/scala/za/co/absa/atum/agent/model/MeasuresBuilder.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ private [agent] object MeasuresBuilder extends Logging {
4949
case DistinctRecordCount.measureName => DistinctRecordCount(measuredColumns)
5050
case SumOfValuesOfColumn.measureName => SumOfValuesOfColumn(measuredColumns.head)
5151
case AbsSumOfValuesOfColumn.measureName => AbsSumOfValuesOfColumn(measuredColumns.head)
52+
case SumOfTruncatedValuesOfColumn.measureName => SumOfTruncatedValuesOfColumn(measuredColumns.head)
53+
case AbsSumOfTruncatedValuesOfColumn.measureName => AbsSumOfTruncatedValuesOfColumn(measuredColumns.head)
5254
case SumOfHashesOfColumn.measureName => SumOfHashesOfColumn(measuredColumns.head)
5355
}
5456
}.toOption

agent/src/test/scala/za/co/absa/atum/agent/model/AtumMeasureUnitTests.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
3737
measuredCol = "salary"
3838
)
3939
val salarySum = SumOfValuesOfColumn(measuredCol = "salary")
40+
val salaryTruncSum = SumOfTruncatedValuesOfColumn(measuredCol = "salary")
41+
val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn(measuredCol = "salary")
4042
val sumOfHashes: AtumMeasure = SumOfHashesOfColumn(measuredCol = "id")
4143

4244
// AtumContext contains `Measurement`
@@ -86,12 +88,34 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
8688
.removeMeasure(salaryAbsSum)
8789
)
8890

91+
val dfExtraPersonWithDecimalSalary = spark
92+
.createDataFrame(
93+
Seq(
94+
("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
95+
("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
96+
)
97+
)
98+
.toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")
99+
100+
val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)
101+
102+
dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
103+
atumContextWithSalaryAbsMeasure
104+
.removeMeasure(measureIds)
105+
.removeMeasure(salaryAbsSum)
106+
)
107+
108+
89109
val dfPersonCntResult = measureIds.function(dfPersons)
90110
val dfFullCntResult = measureIds.function(dfFull)
91111
val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
92112
val dfFullHashResult = sumOfHashes.function(dfFull)
93113
val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
94114
val dfFullSalarySumResult = salarySum.function(dfFull)
115+
val dfExtraPersonSalarySumTruncResult = salaryTruncSum.function(dfExtraDecimalPerson)
116+
val dfFullSalarySumTruncResult = salaryTruncSum.function(dfFull)
117+
val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
118+
val dfFullSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfFull)
95119

96120
// Assertions
97121
assert(dfPersonCntResult.resultValue == "1000")
@@ -106,6 +130,14 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
106130
assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
107131
assert(dfFullSalarySumResult.resultValue == "2987144")
108132
assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
133+
assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
134+
assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
135+
assert(dfFullSalarySumTruncResult.resultValue == "2987144")
136+
assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
137+
assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
138+
assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
139+
assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
140+
assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
109141
}
110142

111143
"AbsSumOfValuesOfColumn" should "return expected value" in {
@@ -187,4 +219,33 @@ class AtumMeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase
187219
assert(result.resultValueType == ResultValueType.BigDecimalValue)
188220
}
189221

222+
"SumOfTruncatedValuesOfColumn" should "return expected value" in {
  // Fix: the spec string previously named a nonexistent class
  // ("SumTruncOfValuesOfColumn") and the local was misleadingly called
  // `distinctCount` (copy-paste from a distinct-count test).
  val truncatedSumMeasure = SumOfTruncatedValuesOfColumn("colA")

  val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("1.32", "b2"))
  val rdd = spark.sparkContext.parallelize(data)

  val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
  val df = spark.createDataFrame(rdd, schema)

  val result = truncatedSumMeasure.function(df)

  // Values truncate toward zero: 1 + (-1) + 1 + 1 = 2
  assert(result.resultValue == "2")
  assert(result.resultValueType == ResultValueType.BigDecimalValue)
}
236+
237+
"AbsSumOfTruncatedValuesOfColumn" should "return expected value" in {
  // Fix: the spec string previously named a nonexistent class
  // ("AbsSumTruncOfValuesOfColumn") and the local was misleadingly called
  // `distinctCount` (copy-paste from a distinct-count test).
  val absTruncatedSumMeasure = AbsSumOfTruncatedValuesOfColumn("colA")

  val data = List(Row("1.98", "b1"), Row("-1.76", "b2"), Row("1.54", "b2"), Row("-1.32", "b2"))
  val rdd = spark.sparkContext.parallelize(data)

  val schema = StructType(Array(StructField("colA", StringType), StructField("colB", StringType)))
  val df = spark.createDataFrame(rdd, schema)

  val result = absTruncatedSumMeasure.function(df)

  // Truncate toward zero then take absolute values: |1| + |-1| + |1| + |-1| = 4
  assert(result.resultValue == "4")
  assert(result.resultValueType == ResultValueType.BigDecimalValue)
}
190251
}

agent/src/test/scala/za/co/absa/atum/agent/model/MeasureUnitTests.scala

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package za.co.absa.atum.agent.model
1919
import org.scalatest.flatspec.AnyFlatSpec
2020
import org.scalatest.matchers.should.Matchers
2121
import za.co.absa.atum.agent.AtumAgent
22-
import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn}
22+
import za.co.absa.atum.agent.model.AtumMeasure.{AbsSumOfValuesOfColumn, RecordCount, SumOfHashesOfColumn, SumOfValuesOfColumn, SumOfTruncatedValuesOfColumn, AbsSumOfTruncatedValuesOfColumn}
2323
import za.co.absa.spark.commons.test.SparkTestBase
2424
import za.co.absa.atum.agent.AtumContext._
2525
import za.co.absa.atum.model.ResultValueType
@@ -30,11 +30,13 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
3030
"Measure" should "be based on the dataframe" in {
3131

3232
// Measures
33-
val measureIds: AtumMeasure = RecordCount()
34-
val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
33+
val measureIds: AtumMeasure = RecordCount()
34+
val salaryAbsSum: AtumMeasure = AbsSumOfValuesOfColumn("salary")
35+
val sumOfHashes: AtumMeasure = SumOfHashesOfColumn("id")
3536

36-
val salarySum = SumOfValuesOfColumn("salary")
37-
val sumOfHashes: AtumMeasure = SumOfHashesOfColumn("id")
37+
val salarySum = SumOfValuesOfColumn("salary")
38+
val salaryAbsTruncSum = AbsSumOfTruncatedValuesOfColumn("salary")
39+
val salaryTruncSum = SumOfTruncatedValuesOfColumn("salary")
3840

3941
// AtumContext contains `Measurement`
4042
val atumContextInstanceWithRecordCount = AtumAgent
@@ -83,12 +85,33 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
8385
.removeMeasure(salaryAbsSum)
8486
)
8587

86-
val dfPersonCntResult = measureIds.function(dfPersons)
87-
val dfFullCntResult = measureIds.function(dfFull)
88-
val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
89-
val dfFullHashResult = sumOfHashes.function(dfFull)
90-
val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
91-
val dfFullSalarySumResult = salarySum.function(dfFull)
88+
val dfExtraPersonWithDecimalSalary = spark
89+
.createDataFrame(
90+
Seq(
91+
("id", "firstName", "lastName", "email", "email2", "profession", "3000.98"),
92+
("id", "firstName", "lastName", "email", "email2", "profession", "-1000.76")
93+
)
94+
)
95+
.toDF("id", "firstName", "lastName", "email", "email2", "profession", "salary")
96+
97+
val dfExtraDecimalPerson = dfExtraPersonWithDecimalSalary.union(dfPersons)
98+
99+
dfExtraDecimalPerson.createCheckpoint("a checkpoint name")(
100+
atumContextWithSalaryAbsMeasure
101+
.removeMeasure(measureIds)
102+
.removeMeasure(salaryAbsSum)
103+
)
104+
105+
val dfPersonCntResult = measureIds.function(dfPersons)
106+
val dfFullCntResult = measureIds.function(dfFull)
107+
val dfFullSalaryAbsSumResult = salaryAbsSum.function(dfFull)
108+
val dfFullHashResult = sumOfHashes.function(dfFull)
109+
val dfExtraPersonSalarySumResult = salarySum.function(dfExtraPerson)
110+
val dfFullSalarySumResult = salarySum.function(dfFull)
111+
val dfExtraPersonSalarySumTruncResult = salaryTruncSum.function(dfExtraDecimalPerson)
112+
val dfFullSalarySumTruncResult = salaryTruncSum.function(dfFull)
113+
val dfExtraPersonSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfExtraDecimalPerson)
114+
val dfFullSalaryAbsSumTruncResult = salaryAbsTruncSum.function(dfFull)
92115

93116
// Assertions
94117
assert(dfPersonCntResult.resultValue == "1000")
@@ -103,6 +126,14 @@ class MeasureUnitTests extends AnyFlatSpec with Matchers with SparkTestBase { se
103126
assert(dfExtraPersonSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
104127
assert(dfFullSalarySumResult.resultValue == "2987144")
105128
assert(dfFullSalarySumResult.resultValueType == ResultValueType.BigDecimalValue)
129+
assert(dfExtraPersonSalarySumTruncResult.resultValue == "2989144")
130+
assert(dfExtraPersonSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
131+
assert(dfFullSalarySumTruncResult.resultValue == "2987144")
132+
assert(dfFullSalarySumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
133+
assert(dfExtraPersonSalaryAbsSumTruncResult.resultValue == "2991144")
134+
assert(dfExtraPersonSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
135+
assert(dfFullSalaryAbsSumTruncResult.resultValue == "2987144")
136+
assert(dfFullSalaryAbsSumTruncResult.resultValueType == ResultValueType.BigDecimalValue)
106137
}
107138

108139
}

agent/src/test/scala/za/co/absa/atum/agent/model/MeasuresBuilderUnitTests.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike {
2828
MeasureDTO("distinctCount", Seq("distinctCountCol")),
2929
MeasureDTO("aggregatedTotal", Seq("aggregatedTotalCol")),
3030
MeasureDTO("absAggregatedTotal", Seq("absAggregatedTotalCol")),
31+
MeasureDTO("aggregatedTruncTotal", Seq("aggregatedTruncTotalCol")),
32+
MeasureDTO("absAggregatedTruncTotal", Seq("absAggregatedTruncTotalCol")),
3133
MeasureDTO("hashCrc32", Seq("hashCrc32Col"))
3234
)
3335

@@ -36,6 +38,8 @@ class MeasuresBuilderUnitTests extends AnyFlatSpecLike {
3638
DistinctRecordCount(Seq("distinctCountCol")),
3739
SumOfValuesOfColumn("aggregatedTotalCol"),
3840
AbsSumOfValuesOfColumn("absAggregatedTotalCol"),
41+
SumOfTruncatedValuesOfColumn("aggregatedTruncTotalCol"),
42+
AbsSumOfTruncatedValuesOfColumn("absAggregatedTruncTotalCol"),
3943
SumOfHashesOfColumn("hashCrc32Col")
4044
)
4145

0 commit comments

Comments
 (0)