Skip to content

Commit f9fbb49

Browse files
yruslan and cerveada authored
#58 Improve casting errors in the error column (#59)
* WIP * Fix more tests. * Remove stdSparkTestBase * Fix tests failing in CI/CD * Apply suggestions from code review Co-authored-by: Adam Cervenka <[email protected]> * Fix some of PR suggestions. * Fix another set of PR suggestions - parsers no longer depend on source and target string representation of types. * Remove more redundant stuff introduced in the PR. * Remove unnecessary imports. * Remove more unnecessary imports. * #58 Include default patterns in casting errors. * Fixup * #58 According to PR suggestions changed the casting error message. --------- Co-authored-by: Adam Cervenka <[email protected]>
1 parent 685b658 commit f9fbb49

23 files changed

+390
-327
lines changed

src/main/scala/za/co/absa/standardization/StandardizationErrorMessage.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,24 @@
1616

1717
package za.co.absa.standardization
1818

19-
import za.co.absa.standardization.ErrorMessage
20-
import za.co.absa.standardization.config.{ErrorCodesConfig}
19+
import za.co.absa.standardization.config.ErrorCodesConfig
2120

2221
object StandardizationErrorMessage {
2322

24-
def stdCastErr(errCol: String, rawValue: String)(implicit errorCodes: ErrorCodesConfig): ErrorMessage = ErrorMessage(
25-
"stdCastError",
26-
errorCodes.castError,
27-
"Standardization Error - Type cast",
28-
errCol,
29-
Seq(rawValue))
23+
def stdCastErr(errCol: String, rawValue: String, sourceType: String, targetType: String, pattern: Option[String])(implicit errorCodes: ErrorCodesConfig): ErrorMessage = {
24+
val sourceTypeFull = pattern match {
25+
case Some(pattern) if pattern.nonEmpty => s"'$sourceType' ($pattern)"
26+
case _ => s"'$sourceType'"
27+
}
28+
29+
ErrorMessage(
30+
"stdCastError",
31+
errorCodes.castError,
32+
s"Cast from $sourceTypeFull to '$targetType'",
33+
errCol,
34+
Seq(rawValue))
35+
}
36+
3037
def stdNullErr(errCol: String)(implicit errorCodes: ErrorCodesConfig): ErrorMessage = ErrorMessage(
3138
"stdNullError",
3239
errorCodes.nullError,

src/main/scala/za/co/absa/standardization/stages/TypeParser.scala

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,30 @@
1616

1717
package za.co.absa.standardization.stages
1818

19-
import java.security.InvalidParameterException
20-
import java.sql.Timestamp
21-
import java.util.Date
22-
import java.util.regex.Pattern
2319
import org.apache.spark.sql.Column
2420
import org.apache.spark.sql.expressions.UserDefinedFunction
2521
import org.apache.spark.sql.functions._
2622
import org.apache.spark.sql.types._
2723
import org.slf4j.{Logger, LoggerFactory}
28-
import za.co.absa.standardization.ErrorMessage
2924
import za.co.absa.spark.commons.implicits.ColumnImplicits.ColumnEnhancements
3025
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
3126
import za.co.absa.spark.commons.utils.SchemaUtils
3227
import za.co.absa.spark.hofs.transform
33-
import za.co.absa.standardization.StandardizationErrorMessage
28+
import za.co.absa.standardization.{ErrorMessage, StandardizationErrorMessage}
3429
import za.co.absa.standardization.config.StandardizationConfig
3530
import za.co.absa.standardization.implicits.StdColumnImplicits.StdColumnEnhancements
36-
import za.co.absa.standardization.schema.{MetadataValues, StdSchemaUtils}
3731
import za.co.absa.standardization.schema.StdSchemaUtils.FieldWithSource
32+
import za.co.absa.standardization.schema.{MetadataValues, StdSchemaUtils}
3833
import za.co.absa.standardization.time.DateTimePattern
3934
import za.co.absa.standardization.typeClasses.{DoubleLike, LongLike}
4035
import za.co.absa.standardization.types.TypedStructField._
4136
import za.co.absa.standardization.types.{ParseOutput, TypeDefaults, TypedStructField}
4237
import za.co.absa.standardization.udf.{UDFBuilder, UDFNames}
4338

39+
import java.security.InvalidParameterException
40+
import java.sql.Timestamp
41+
import java.util.Date
42+
import java.util.regex.Pattern
4443
import scala.reflect.runtime.universe._
4544
import scala.util.{Random, Try}
4645

@@ -265,18 +264,30 @@ object TypeParser {
265264
override protected def standardizeAfterCheck(stdConfig: StandardizationConfig)(implicit logger: Logger): ParseOutput = {
266265
val castedCol: Column = assemblePrimitiveCastLogic
267266
val castHasError: Column = assemblePrimitiveCastErrorLogic(castedCol)
267+
val patternOpt = field.pattern.toOption.flatten.map(_.pattern)
268+
val patternColumn = lit(patternOpt.orNull)
268269

269270
val err: Column = if (field.nullable) {
270271
when(column.isNotNull and castHasError, // cast failed
271-
array(callUDF(UDFNames.stdCastErr, lit(columnIdForUdf), column.cast(StringType)))
272+
array(callUDF(UDFNames.stdCastErr,
273+
lit(columnIdForUdf),
274+
column.cast(StringType),
275+
lit(origType.typeName),
276+
lit(field.dataType.typeName),
277+
patternColumn))
272278
).otherwise( // everything is OK
273279
typedLit(Seq.empty[ErrorMessage])
274280
)
275281
} else {
276282
when(column.isNull, // NULL not allowed
277283
array(callUDF(UDFNames.stdNullErr, lit(columnIdForUdf)))
278284
).otherwise( when(castHasError, // cast failed
279-
array(callUDF(UDFNames.stdCastErr, lit(columnIdForUdf), column.cast(StringType)))
285+
array(callUDF(UDFNames.stdCastErr,
286+
lit(columnIdForUdf),
287+
column.cast(StringType),
288+
lit(origType.typeName),
289+
lit(field.dataType.typeName),
290+
patternColumn))
280291
).otherwise( // everything is OK
281292
typedLit(Seq.empty[ErrorMessage])
282293
))
@@ -344,7 +355,15 @@ object TypeParser {
344355
}
345356

346357
private def standardizeUsingUdf(stdConfig: StandardizationConfig): ParseOutput = {
347-
val udfFnc: UserDefinedFunction = UDFBuilder.stringUdfViaNumericParser(field.parser.get, field.nullable, columnIdForUdf, stdConfig, defaultValue)
358+
val udfFnc: UserDefinedFunction = UDFBuilder.stringUdfViaNumericParser(
359+
origType,
360+
field.dataType,
361+
field.parser.get,
362+
field.nullable,
363+
columnIdForUdf,
364+
stdConfig,
365+
defaultValue
366+
)
348367
ParseOutput(udfFnc(column)("result").cast(field.dataType).as(fieldOutputName), udfFnc(column)("error"))
349368
}
350369
}

src/main/scala/za/co/absa/standardization/types/TypedStructField.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616

1717
package za.co.absa.standardization.types
1818

19-
import java.sql.{Date, Timestamp}
20-
import java.util.Base64
21-
2219
import org.apache.spark.sql.types._
2320
import za.co.absa.spark.commons.implicits.StructFieldImplicits.StructFieldMetadataEnhancements
2421
import za.co.absa.standardization.ValidationIssue
@@ -28,6 +25,9 @@ import za.co.absa.standardization.time.DateTimePattern
2825
import za.co.absa.standardization.typeClasses.{DoubleLike, LongLike}
2926
import za.co.absa.standardization.types.parsers._
3027
import za.co.absa.standardization.validation.field._
28+
29+
import java.sql.{Date, Timestamp}
30+
import java.util.Base64
3131
import scala.util.{Failure, Success, Try}
3232

3333
sealed abstract class TypedStructField(val structField: StructField)(implicit defaults: TypeDefaults)
@@ -218,7 +218,9 @@ object TypedStructField {
218218
}
219219

220220
// NumericTypeStructField
221-
sealed abstract class NumericTypeStructField[N](structField: StructField, val typeMin: N, val typeMax: N)
221+
sealed abstract class NumericTypeStructField[N](structField: StructField,
222+
val typeMin: N,
223+
val typeMax: N)
222224
(implicit defaults: TypeDefaults)
223225
extends TypedStructFieldTagged[N](structField) {
224226
val allowInfinity: Boolean = false
@@ -277,8 +279,11 @@ object TypedStructField {
277279
val decimalSymbols = patternForParser.map(_.decimalSymbols).getOrElse(defaults.getDecimalSymbols)
278280
Try(IntegralParser.ofRadix(radix, decimalSymbols, Option(typeMin), Option(typeMax)))
279281
} else {
280-
Success(IntegralParser(patternForParser
281-
.getOrElse(NumericPattern(defaults.getDecimalSymbols)), Option(typeMin), Option(typeMax)))
282+
Success(IntegralParser(
283+
patternForParser.getOrElse(NumericPattern(defaults.getDecimalSymbols)),
284+
Option(typeMin),
285+
Option(typeMax)
286+
))
282287
}}
283288
}
284289

src/main/scala/za/co/absa/standardization/types/parsers/FractionalParser.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
package za.co.absa.standardization.types.parsers
1818

19-
import java.text.DecimalFormat
20-
2119
import za.co.absa.standardization.numeric.NumericPattern
2220
import za.co.absa.standardization.typeClasses.DoubleLike
2321

22+
import java.text.DecimalFormat
23+
2424
class FractionalParser[D: DoubleLike] private(override val pattern: NumericPattern,
2525
override val min: Option[D],
2626
override val max: Option[D])

src/main/scala/za/co/absa/standardization/udf/UDFBuilder.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package za.co.absa.standardization.udf
1818

1919
import org.apache.spark.sql.expressions.UserDefinedFunction
2020
import org.apache.spark.sql.functions.udf
21+
import org.apache.spark.sql.types.DataType
2122
import za.co.absa.standardization.config.StandardizationConfig
2223
import za.co.absa.standardization.types.parsers.NumericParser
2324
import za.co.absa.standardization.types.parsers.NumericParser.NumericParserException
@@ -26,7 +27,9 @@ import scala.reflect.runtime.universe._
2627
import scala.util.{Failure, Success}
2728

2829
object UDFBuilder {
29-
def stringUdfViaNumericParser[T: TypeTag](parser: NumericParser[T],
30+
def stringUdfViaNumericParser[T: TypeTag](sourceDataType: DataType,
31+
targetDataType: DataType,
32+
parser: NumericParser[T],
3033
columnNullable: Boolean,
3134
columnNameForError: String,
3235
stdConfig: StandardizationConfig,
@@ -38,10 +41,12 @@ object UDFBuilder {
3841
val vColumnNullable = columnNullable
3942
val vStdConfig = stdConfig
4043

41-
udf[UDFResult[T], String](numericParserToTyped(_, vParser, vColumnNullable, vColumnNameForError, vStdConfig, vDefaultValue))
44+
udf[UDFResult[T], String](numericParserToTyped(_, sourceDataType, targetDataType, vParser, vColumnNullable, vColumnNameForError, vStdConfig, vDefaultValue))
4245
}
4346

4447
private def numericParserToTyped[T](input: String,
48+
sourceDataType: DataType,
49+
targetDataType: DataType,
4550
parser: NumericParser[T],
4651
columnNullable: Boolean,
4752
columnNameForError: String,
@@ -52,7 +57,7 @@ object UDFBuilder {
5257
case None if columnNullable => Success(None)
5358
case None => Failure(nullException)
5459
}
55-
UDFResult.fromTry(result, columnNameForError, input, stdConfig, defaultValue)
60+
UDFResult.fromTry(result, columnNameForError, input, sourceDataType.typeName, targetDataType.typeName, None, stdConfig, defaultValue)
5661
}
5762

5863
private val nullException = new NumericParserException("Null value on input for non-nullable field")

src/main/scala/za/co/absa/standardization/udf/UDFLibrary.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ class UDFLibrary(stdConfig: StandardizationConfig) extends OncePerSparkSession w
3535

3636
override protected def registerBody(spark: SparkSession): Unit = {
3737

38-
spark.udf.register(stdCastErr, { (errCol: String, rawValue: String) =>
39-
StandardizationErrorMessage.stdCastErr(errCol, rawValue)
38+
spark.udf.register(stdCastErr, { (errCol: String, rawValue: String, sourceType: String, targetType: String, pattern: String) =>
39+
StandardizationErrorMessage.stdCastErr(errCol, rawValue, sourceType, targetType, Option(pattern))
4040
})
4141

4242
spark.udf.register(stdNullErr, { errCol: String => StandardizationErrorMessage.stdNullErr(errCol) })

src/main/scala/za/co/absa/standardization/udf/UDFResult.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,19 @@ object UDFResult {
3030
UDFResult(result, Seq.empty)
3131
}
3232

33-
def fromTry[T](result: Try[Option[T]], columnName: String, rawValue: String, stdConfig: StandardizationConfig, defaultValue: Option[T] = None): UDFResult[T] = {
33+
def fromTry[T](result: Try[Option[T]],
34+
columnName: String,
35+
rawValue: String,
36+
sourceType: String,
37+
targetType: String,
38+
pattern: Option[String],
39+
stdConfig: StandardizationConfig,
40+
defaultValue: Option[T] = None): UDFResult[T] = {
3441
result match {
3542
case Success(success) => UDFResult.success(success)
3643
case Failure(_) if Option(rawValue).isEmpty => UDFResult(defaultValue, Seq(StandardizationErrorMessage.stdNullErr(columnName)(stdConfig.errorCodes)))
37-
case Failure(_) => UDFResult(defaultValue, Seq(StandardizationErrorMessage.stdCastErr(columnName, rawValue)(stdConfig.errorCodes)))
44+
case Failure(_) =>
45+
UDFResult(defaultValue, Seq(StandardizationErrorMessage.stdCastErr(columnName, rawValue, sourceType, targetType, pattern)(stdConfig.errorCodes)))
3846
}
3947
}
4048
}

src/test/scala/za/co/absa/standardization/TestSamples.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ object TestSamples {
8585
val resData = List(
8686
StdEmployee(name = "John0", surname = "Unknown Surname", hoursWorked = Some(List(8, 7, 8, 9, 12, 0)),
8787
employeeNumbers = List(EmployeeNumberStd("SAP", List(456, 123)), EmployeeNumberStd("WD", List(5))), startDate = new java.sql.Date(startDate), errCol = List(StandardizationErrorMessage.stdNullErr("surname"), StandardizationErrorMessage.stdNullErr("hoursWorked[*]"))),
88-
StdEmployee(name = "John1", surname = "Doe1", hoursWorked = Some(List(99, 99, 76, 12, 12, 24)), startDate = new java.sql.Date(0), errCol = List(StandardizationErrorMessage.stdCastErr("startDate", "Two Thousand Something"))),
88+
StdEmployee(name = "John1", surname = "Doe1", hoursWorked = Some(List(99, 99, 76, 12, 12, 24)), startDate = new java.sql.Date(0), errCol = List(StandardizationErrorMessage.stdCastErr("startDate", "Two Thousand Something", "string", "date", Some("yyyy-MM-dd")))),
8989
StdEmployee(name = "John2", surname = "Unknown Surname", hoursWorked = None, startDate = new java.sql.Date(startDate), updated = Some(Timestamp.valueOf("2015-07-16 13:32:24")), errCol = List(StandardizationErrorMessage.stdNullErr("surname"), StandardizationErrorMessage.stdNullErr("hoursWorked"))),
9090
StdEmployee(name = "John3", surname = "Unknown Surname", hoursWorked = Some(List()), startDate = new java.sql.Date(startDate), updated = Some(Timestamp.valueOf("2015-07-16 10:32:24")), errCol = List(StandardizationErrorMessage.stdNullErr("surname"))))
9191

src/test/scala/za/co/absa/standardization/interpreter/DateTimeSuite.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ class DateTimeSuite extends AnyFunSuite with SparkTestBase with LoggerTestBase {
107107
null,
108108
ts, ts, ts, null, ts0, ts0,
109109
List(
110-
StandardizationErrorMessage.stdCastErr("dateSampleWrong1","10-20-2017"),
111-
StandardizationErrorMessage.stdCastErr("dateSampleWrong2","201711"),
112-
StandardizationErrorMessage.stdCastErr("dateSampleWrong3",""),
113-
StandardizationErrorMessage.stdCastErr("timestampSampleWrong1", "20171020T081131"),
114-
StandardizationErrorMessage.stdCastErr("timestampSampleWrong2", "2017-10-20t081131"),
115-
StandardizationErrorMessage.stdCastErr("timestampSampleWrong3", "2017-10-20")
110+
StandardizationErrorMessage.stdCastErr("dateSampleWrong1","10-20-2017", "string", "date", Some("dd-MM-yyyy")),
111+
StandardizationErrorMessage.stdCastErr("dateSampleWrong2","201711", "string", "date", Some("dd-MM-yyyy")),
112+
StandardizationErrorMessage.stdCastErr("dateSampleWrong3","", "string", "date", Some("dd-MM-yyyy")),
113+
StandardizationErrorMessage.stdCastErr("timestampSampleWrong1", "20171020T081131", "string", "timestamp", Some("yyyy-MM-dd'T'HH:mm:ss")),
114+
StandardizationErrorMessage.stdCastErr("timestampSampleWrong2", "2017-10-20t081131", "string", "timestamp", Some("yyyy-MM-dd'T'HH:mm:ss")),
115+
StandardizationErrorMessage.stdCastErr("timestampSampleWrong3", "2017-10-20", "string", "timestamp", Some("yyyy-MM-dd'T'HH:mm:ss"))
116116
)
117117
))
118118
val std: Dataset[Row] = Standardization.standardize(data, schemaOk, stdConfig)

0 commit comments

Comments (0)