Commit ecab7cb

#8 Make type defaults configurable
1 parent 0b94b6c commit ecab7cb

Note: large commits have some content hidden by default, so only a subset of the 41 changed files is shown below.

41 files changed: +291 / -262 lines
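
In short: the hard-wired GlobalDefaults object is replaced by a TypeDefaults trait carried on the StandardizationConfig, so callers can now choose the fallback values used per Spark data type. A minimal wiring sketch follows; the fromDefault() factory on BasicStandardizationConfig is assumed here, mirroring the other Basic*Config companions visible in the hunks below.

import za.co.absa.standardization.config.BasicStandardizationConfig
import za.co.absa.standardization.types.CommonTypeDefaults

// Custom defaults: as CommonTypeDefaults, except non-nullable string
// columns fall back to "N/A" and timestamps default to the UTC zone.
object MyTypeDefaults extends CommonTypeDefaults {
  override val stringTypeDefault: String = "N/A"
  override def defaultTimestampTimeZone: Option[String] = Some("UTC")
}

val config = BasicStandardizationConfig
  .fromDefault()                       // assumed factory method
  .copy(typeDefaults = MyTypeDefaults) // the newly configurable member

// Standardization.standardize(...) then resolves its implicit TypeDefaults
// from config.typeDefaults (see the Standardization.scala hunks below).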

src/main/scala/za/co/absa/standardization/SchemaValidator.scala

Lines changed: 2 additions & 4 deletions
@@ -18,7 +18,7 @@ package za.co.absa.standardization

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types._
-import za.co.absa.standardization.types.{Defaults, GlobalDefaults, TypedStructField}
+import za.co.absa.standardization.types.{TypeDefaults, TypedStructField}
 import za.co.absa.standardization.validation.field.FieldValidationIssue

 import scala.collection.mutable.ListBuffer
@@ -27,15 +27,13 @@ import scala.collection.mutable.ListBuffer
  * Object responsible for Spark schema validation against self inconsistencies (not against the actual data)
  */
 object SchemaValidator {
-  private implicit val defaults: Defaults = GlobalDefaults
-
   /**
    * Validate a schema
    *
    * @param schema A Spark schema
    * @return A list of ValidationErrors objects, each containing a column name and the list of errors and warnings
    */
-  def validateSchema(schema: StructType): List[FieldValidationIssue] = {
+  def validateSchema(schema: StructType)(implicit defaults: TypeDefaults): List[FieldValidationIssue] = {
     var errorsAccumulator = new ListBuffer[FieldValidationIssue]
     val flatSchema = flattenSchema(schema)
     for {s <- flatSchema} {
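
Call sites change accordingly: a TypeDefaults instance must now be in implicit scope, where previously GlobalDefaults was baked into the object. A minimal sketch of the new calling convention:

import org.apache.spark.sql.types.StructType
import za.co.absa.standardization.SchemaValidator
import za.co.absa.standardization.types.{CommonTypeDefaults, TypeDefaults}

// The defaults are no longer hard-wired; callers bring their own.
implicit val defaults: TypeDefaults = CommonTypeDefaults

def check(schema: StructType): Unit =
  SchemaValidator.validateSchema(schema).foreach(println)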

src/main/scala/za/co/absa/standardization/Standardization.scala

Lines changed: 5 additions & 5 deletions
@@ -18,18 +18,17 @@ package za.co.absa.standardization

 import org.apache.hadoop.conf.Configuration
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{StructType, _}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.{Column, DataFrame, SparkSession}
 import org.slf4j.{Logger, LoggerFactory}

 import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
 import za.co.absa.standardization.config.{DefaultStandardizationConfig, StandardizationConfig}
 import za.co.absa.standardization.stages.{SchemaChecker, TypeParser}
-import za.co.absa.standardization.types.{Defaults, GlobalDefaults, ParseOutput}
+import za.co.absa.standardization.types.{CommonTypeDefaults, ParseOutput, TypeDefaults}
 import za.co.absa.standardization.udf.{UDFLibrary, UDFNames}

 object Standardization {
-  private implicit val defaults: Defaults = GlobalDefaults
   private val logger: Logger = LoggerFactory.getLogger(this.getClass)
   final val DefaultColumnNameOfCorruptRecord = "_corrupt_record"

@@ -41,6 +40,7 @@ object Standardization {
                   (implicit sparkSession: SparkSession): DataFrame = {
     implicit val udfLib: UDFLibrary = new UDFLibrary(standardizationConfig)
     implicit val hadoopConf: Configuration = sparkSession.sparkContext.hadoopConfiguration
+    implicit val defaults: TypeDefaults = standardizationConfig.typeDefaults

     logger.info(s"Step 1: Schema validation")
     validateSchemaAgainstSelfInconsistencies(schema)
@@ -67,15 +67,15 @@ object Standardization {


   private def validateSchemaAgainstSelfInconsistencies(expSchema: StructType)
-                                                       (implicit spark: SparkSession): Unit = {
+                                                       (implicit spark: SparkSession, defaults: TypeDefaults): Unit = {
     val validationErrors = SchemaChecker.validateSchemaAndLog(expSchema)
     if (validationErrors._1.nonEmpty) {
       throw new ValidationException("A fatal schema validation error occurred.", validationErrors._1)
     }
   }

   private def standardizeDataset(df: DataFrame, expSchema: StructType, stdConfig: StandardizationConfig)
-                                 (implicit spark: SparkSession, udfLib: UDFLibrary): DataFrame = {
+                                 (implicit spark: SparkSession, udfLib: UDFLibrary, defaults: TypeDefaults): DataFrame = {

     val rowErrors: List[Column] = gatherRowErrors(df.schema)
     val (stdCols, errorCols, oldErrorColumn) = expSchema.fields.foldLeft(List.empty[Column], rowErrors, None: Option[Column]) {

src/main/scala/za/co/absa/standardization/config/BasicStandardizationConfig.scala

Lines changed: 4 additions & 0 deletions
@@ -16,9 +16,12 @@

 package za.co.absa.standardization.config

+import za.co.absa.standardization.types.{CommonTypeDefaults, TypeDefaults}
+
 case class BasicStandardizationConfig(failOnInputNotPerSchema: Boolean,
                                       errorCodes: ErrorCodesConfig,
                                       metadataColumns: MetadataColumnsConfig,
+                                      typeDefaults: TypeDefaults,
                                       errorColumn: String,
                                       timezone: String) extends StandardizationConfig

@@ -28,6 +31,7 @@ object BasicStandardizationConfig {
       DefaultStandardizationConfig.failOnInputNotPerSchema,
       BasicErrorCodesConfig.fromDefault(),
       BasicMetadataColumnsConfig.fromDefault(),
+      CommonTypeDefaults,
       DefaultStandardizationConfig.errorColumn,
       DefaultStandardizationConfig.timezone
     )

src/main/scala/za/co/absa/standardization/config/DefaultStandardizationConfig.scala

Lines changed: 3 additions & 0 deletions
@@ -16,10 +16,13 @@

 package za.co.absa.standardization.config

+import za.co.absa.standardization.types.{CommonTypeDefaults, TypeDefaults}
+
 object DefaultStandardizationConfig extends StandardizationConfig {
   val errorCodes: ErrorCodesConfig = DefaultErrorCodesConfig
   val metadataColumns: MetadataColumnsConfig = DefaultMetadataColumnsConfig
   val failOnInputNotPerSchema: Boolean = false
+  val typeDefaults: TypeDefaults = CommonTypeDefaults
   val errorColumn: String = "errCol"
   val timezone: String = "UTC"
 }

src/main/scala/za/co/absa/standardization/config/StandardizationConfig.scala

Lines changed: 3 additions & 0 deletions
@@ -16,10 +16,13 @@

 package za.co.absa.standardization.config

+import za.co.absa.standardization.types.TypeDefaults
+
 trait StandardizationConfig {
   val failOnInputNotPerSchema: Boolean
   val errorCodes: ErrorCodesConfig
   val metadataColumns: MetadataColumnsConfig
+  val typeDefaults: TypeDefaults
   val errorColumn: String
   val timezone: String
 }
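
Consequently, every StandardizationConfig implementation must now also supply typeDefaults. A hypothetical custom config, borrowing its other members from DefaultStandardizationConfig above:

import za.co.absa.standardization.config._
import za.co.absa.standardization.types.{CommonTypeDefaults, TypeDefaults}

object MyStandardizationConfig extends StandardizationConfig {
  val failOnInputNotPerSchema: Boolean = true
  val errorCodes: ErrorCodesConfig = DefaultErrorCodesConfig
  val metadataColumns: MetadataColumnsConfig = DefaultMetadataColumnsConfig
  val typeDefaults: TypeDefaults = CommonTypeDefaults // the new required member
  val errorColumn: String = "errCol"
  val timezone: String = "UTC"
}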

src/main/scala/za/co/absa/standardization/stages/SchemaChecker.scala

Lines changed: 2 additions & 1 deletion
@@ -20,6 +20,7 @@ import org.apache.log4j.{LogManager, Logger}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.types.StructType
 import za.co.absa.standardization.SchemaValidator.{validateErrorColumn, validateSchema}
+import za.co.absa.standardization.types.TypeDefaults
 import za.co.absa.standardization.{ValidationError, ValidationIssue, ValidationWarning}

 object SchemaChecker {
@@ -32,7 +33,7 @@ object SchemaChecker {
    * @param schema A Spark schema
    */
   def validateSchemaAndLog(schema: StructType)
-                          (implicit spark: SparkSession): (Seq[String], Seq[String]) = {
+                          (implicit spark: SparkSession, defaults: TypeDefaults): (Seq[String], Seq[String]) = {
     val failures = validateSchema(schema) ::: validateErrorColumn(schema)

     type ColName = String

src/main/scala/za/co/absa/standardization/stages/TypeParser.scala

Lines changed: 19 additions & 19 deletions
@@ -37,7 +37,7 @@ import za.co.absa.standardization.schema.StdSchemaUtils.FieldWithSource
 import za.co.absa.standardization.time.DateTimePattern
 import za.co.absa.standardization.typeClasses.{DoubleLike, LongLike}
 import za.co.absa.standardization.types.TypedStructField._
-import za.co.absa.standardization.types.{Defaults, ParseOutput, TypedStructField}
+import za.co.absa.standardization.types.{TypeDefaults, ParseOutput, TypedStructField}
 import za.co.absa.standardization.udf.{UDFBuilder, UDFLibrary, UDFNames}

 import scala.reflect.runtime.universe._
@@ -135,7 +135,7 @@ object TypeParser {
             origSchema: StructType,
             stdConfig: StandardizationConfig,
             failOnInputNotPerSchema: Boolean = true)
-           (implicit udfLib: UDFLibrary, defaults: Defaults): ParseOutput = {
+           (implicit udfLib: UDFLibrary, defaults: TypeDefaults): ParseOutput = {
     // udfLib implicit is present for error column UDF implementation
     val sourceName = SchemaUtils.appendPath(path, field.sourceName)
     val origField = origSchema.getField(sourceName)
@@ -162,7 +162,7 @@ object TypeParser {
                origType: DataType,
                failOnInputNotPerSchema: Boolean,
                isArrayElement: Boolean = false)
-              (implicit defaults: Defaults): TypeParser[_] = {
+              (implicit defaults: TypeDefaults): TypeParser[_] = {
     val parserClass: (String, Column, DataType, Boolean, Boolean) => TypeParser[_] = field.dataType match {
       case _: ArrayType => ArrayParser(TypedStructField.asArrayTypeStructField(field), _, _, _, _, _)
       case _: StructType => StructParser(TypedStructField.asStructTypeStructField(field), _, _, _, _, _)
@@ -191,7 +191,7 @@ object TypeParser {
                 origType: DataType,
                 failOnInputNotPerSchema: Boolean,
                 isArrayElement: Boolean)
-               (implicit defaults: Defaults) extends TypeParser[Any] {
+               (implicit defaults: TypeDefaults) extends TypeParser[Any] {

     override def fieldType: ArrayType = {
       field.dataType
@@ -226,7 +226,7 @@ object TypeParser {
                 origType: DataType,
                 failOnInputNotPerSchema: Boolean,
                 isArrayElement: Boolean)
-               (implicit defaults: Defaults) extends TypeParser[Any] {
+               (implicit defaults: TypeDefaults) extends TypeParser[Any] {
     override def fieldType: StructType = {
       field.dataType
     }
@@ -260,7 +260,7 @@ object TypeParser {
     }
   }

-  private abstract class PrimitiveParser[T](implicit defaults: Defaults) extends TypeParser[T] {
+  private abstract class PrimitiveParser[T](implicit defaults: TypeDefaults) extends TypeParser[T] {
     override protected def standardizeAfterCheck(stdConfig: StandardizationConfig)(implicit logger: Logger): ParseOutput = {
       val castedCol: Column = assemblePrimitiveCastLogic
       val castHasError: Column = assemblePrimitiveCastErrorLogic(castedCol)
@@ -298,12 +298,12 @@ object TypeParser {
     }
   }

-  private abstract class ScalarParser[T](implicit defaults: Defaults) extends PrimitiveParser[T] {
+  private abstract class ScalarParser[T](implicit defaults: TypeDefaults) extends PrimitiveParser[T] {
     override def assemblePrimitiveCastLogic: Column = column.cast(field.dataType)
   }

   private abstract class NumericParser[N: TypeTag](override val field: NumericTypeStructField[N])
-                                                  (implicit defaults: Defaults) extends ScalarParser[N] {
+                                                  (implicit defaults: TypeDefaults) extends ScalarParser[N] {
     override protected def standardizeAfterCheck(stdConfig: StandardizationConfig)(implicit logger: Logger): ParseOutput = {
       if (field.needsUdfParsing) {
         standardizeUsingUdf(stdConfig)
@@ -355,7 +355,7 @@ object TypeParser {
               failOnInputNotPerSchema: Boolean,
               isArrayElement: Boolean,
               overflowableTypes: Set[DataType])
-             (implicit defaults: Defaults) extends NumericParser[N](field) {
+             (implicit defaults: TypeDefaults) extends NumericParser[N](field) {
     override protected def assemblePrimitiveCastErrorLogic(castedCol: Column): Column = {
       val basicLogic: Column = super.assemblePrimitiveCastErrorLogic(castedCol)

@@ -385,7 +385,7 @@ object TypeParser {
               origType: DataType,
               failOnInputNotPerSchema: Boolean,
               isArrayElement: Boolean)
-             (implicit defaults: Defaults)
+             (implicit defaults: TypeDefaults)
     extends NumericParser[BigDecimal](field)
     // NB! loss of precision is not addressed for any DecimalType
     // e.g. 3.141592 will be Standardized to Decimal(10,2) as 3.14
@@ -396,7 +396,7 @@ object TypeParser {
               origType: DataType,
               failOnInputNotPerSchema: Boolean,
               isArrayElement: Boolean)
-             (implicit defaults: Defaults)
+             (implicit defaults: TypeDefaults)
     extends NumericParser[N](field) {
     override protected def assemblePrimitiveCastErrorLogic(castedCol: Column): Column = {
       //NB! loss of precision is not addressed for any fractional type
@@ -414,15 +414,15 @@ object TypeParser {
               origType: DataType,
               failOnInputNotPerSchema: Boolean,
               isArrayElement: Boolean)
-             (implicit defaults: Defaults) extends ScalarParser[String]
+             (implicit defaults: TypeDefaults) extends ScalarParser[String]

   private final case class BinaryParser(field: BinaryTypeStructField,
                                         path: String,
                                         column: Column,
                                         origType: DataType,
                                         failOnInputNotPerSchema: Boolean,
                                         isArrayElement: Boolean)
-                                       (implicit defaults: Defaults) extends PrimitiveParser[Array[Byte]] {
+                                       (implicit defaults: TypeDefaults) extends PrimitiveParser[Array[Byte]] {
     override protected def assemblePrimitiveCastLogic: Column = {
       origType match {
         case BinaryType => column
@@ -450,7 +450,7 @@ object TypeParser {
               origType: DataType,
               failOnInputNotPerSchema: Boolean,
               isArrayElement: Boolean)
-             (implicit defaults: Defaults) extends ScalarParser[Boolean]
+             (implicit defaults: TypeDefaults) extends ScalarParser[Boolean]

   /**
    * Timestamp conversion logic
@@ -474,7 +474,7 @@ object TypeParser {
    * Date  | O                 | ->to_utc_timestamp->to_date
    * Other | ->String->to_date | ->String->to_timestamp->to_utc_timestamp->to_date
    */
-  private abstract class DateTimeParser[T](implicit defaults: Defaults) extends PrimitiveParser[T] {
+  private abstract class DateTimeParser[T](implicit defaults: TypeDefaults) extends PrimitiveParser[T] {
     override val field: DateTimeTypeStructField[T]
     protected val pattern: DateTimePattern = field.pattern.get.get

@@ -551,11 +551,11 @@ object TypeParser {
               origType: DataType,
               failOnInputNotPerSchema: Boolean,
               isArrayElement: Boolean)
-             (implicit defaults: Defaults) extends DateTimeParser[Date] {
+             (implicit defaults: TypeDefaults) extends DateTimeParser[Date] {
     private val defaultTimeZone: Option[String] = field
       .defaultTimeZone
       .map(Option(_))
-      .getOrElse(defaults.getDefaultDateTimeZone)
+      .getOrElse(defaults.defaultDateTimeZone)

     private def applyPatternToStringColumn(column: Column, pattern: String): Column = {
       defaultTimeZone.map(tz =>
@@ -605,12 +605,12 @@ object TypeParser {
               origType: DataType,
               failOnInputNotPerSchema: Boolean,
               isArrayElement: Boolean)
-             (implicit defaults: Defaults) extends DateTimeParser[Timestamp] {
+             (implicit defaults: TypeDefaults) extends DateTimeParser[Timestamp] {

     private val defaultTimeZone: Option[String] = field
       .defaultTimeZone
       .map(Option(_))
-      .getOrElse(defaults.getDefaultTimestampTimeZone)
+      .getOrElse(defaults.defaultTimestampTimeZone)

     private def applyPatternToStringColumn(column: Column, pattern: String): Column = {
       val interim: Column = to_timestamp(column, pattern)
src/main/scala/za/co/absa/standardization/types/CommonTypeDefaults.scala

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+/*
+ * Copyright 2021 ABSA Group Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package za.co.absa.standardization.types
+
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampType}
+import za.co.absa.standardization.numeric.DecimalSymbols
+
+import java.sql.{Date, Timestamp}
+import java.util.Locale
+import scala.util.{Success, Try}
+
+class CommonTypeDefaults extends TypeDefaults {
+  val integerTypeDefault: Int = 0
+  val floatTypeDefault: Float = 0f
+  val byteTypeDefault: Byte = 0.toByte
+  val shortTypeDefault: Short = 0.toShort
+  val doubleTypeDefault: Double = 0.0d
+  val longTypeDefault: Long = 0L
+  val stringTypeDefault: String = ""
+  val binaryTypeDefault: Array[Byte] = Array.empty[Byte]
+  val dateTypeDefault: Date = new Date(0) // Linux epoch
+  val timestampTypeDefault: Timestamp = new Timestamp(0)
+  val booleanTypeDefault: Boolean = false
+  val decimalTypeDefault: (Int, Int) => BigDecimal = { (precision, scale) =>
+    val beforeFloatingPoint = "0" * (precision - scale)
+    val afterFloatingPoint = "0" * scale
+    BigDecimal(s"$beforeFloatingPoint.$afterFloatingPoint")
+  }
+
+  override def defaultTimestampTimeZone: Option[String] = None
+  override def defaultDateTimeZone: Option[String] = None
+
+  override def getDecimalSymbols: DecimalSymbols = DecimalSymbols(Locale.US)
+
+  override def getDataTypeDefaultValue(dt: DataType): Any =
+    dt match {
+      case _: IntegerType => integerTypeDefault
+      case _: FloatType => floatTypeDefault
+      case _: ByteType => byteTypeDefault
+      case _: ShortType => shortTypeDefault
+      case _: DoubleType => doubleTypeDefault
+      case _: LongType => longTypeDefault
+      case _: StringType => stringTypeDefault
+      case _: BinaryType => binaryTypeDefault
+      case _: DateType => dateTypeDefault
+      case _: TimestampType => timestampTypeDefault
+      case _: BooleanType => booleanTypeDefault
+      case t: DecimalType => decimalTypeDefault(t.precision, t.scale)
+      case _ => throw new IllegalStateException(s"No default value defined for data type ${dt.typeName}")
+    }
+
+  override def getDataTypeDefaultValueWithNull(dt: DataType, nullable: Boolean): Try[Option[Any]] = {
+    if (nullable) {
+      Success(None)
+    } else {
+      Try {
+        getDataTypeDefaultValue(dt)
+      }.map(Some(_))
+    }
+  }
+
+  override def getStringPattern(dt: DataType): String = dt match {
+    case DateType => "yyyy-MM-dd"
+    case TimestampType => "yyyy-MM-dd HH:mm:ss"
+    case _: IntegerType
+         | FloatType
+         | ByteType
+         | ShortType
+         | DoubleType
+         | LongType => ""
+    case _: DecimalType => ""
+    case _ => throw new IllegalStateException(s"No default format defined for data type ${dt.typeName}")
+  }
+}
+
+object CommonTypeDefaults extends CommonTypeDefaults
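
A quick illustration of the behavior of the new defaults, derived from the implementation above:

import org.apache.spark.sql.types.{DecimalType, LongType, StringType}
import za.co.absa.standardization.types.CommonTypeDefaults

CommonTypeDefaults.getDataTypeDefaultValue(DecimalType(10, 2))
// => BigDecimal("00000000.00"): zero at the column's precision and scale

CommonTypeDefaults.getDataTypeDefaultValueWithNull(StringType, nullable = true)
// => Success(None): nullable columns default to null

CommonTypeDefaults.getDataTypeDefaultValueWithNull(LongType, nullable = false)
// => Success(Some(0L))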
