|
16 | 16 |
|
17 | 17 | package za.co.absa.standardization.stages |
18 | 18 |
|
19 | | -import org.apache.spark.sql.functions.{to_timestamp,lit, when,coalesce,to_date} |
20 | | -import org.apache.spark.sql.{Column, Row, SparkSession} |
21 | | -import org.apache.spark.sql.types.{DataType, DateType, StringType, StructField, StructType, TimestampType} |
22 | | -import za.co.absa.standardization.types.{TypeDefaults, TypedStructField} |
23 | | -import za.co.absa.standardization.types.TypedStructField.DateTimeTypeStructField |
24 | | -import java.sql.Timestamp |
25 | | -import scala.collection.JavaConverters._ |
| 19 | +import org.apache.spark.sql.functions.{lit, when} |
| 20 | +import org.apache.spark.sql.Column |
| 21 | +import org.apache.spark.sql.types.{DataType, DateType, TimestampType} |
| 22 | +import za.co.absa.standardization.types.parsers.DateTimeParser |
| 23 | +import za.co.absa.standardization.time.{DateTimePattern, InfinityConfig} |
| 24 | + |
| 25 | +import java.sql.{Date, Timestamp} |
26 | 26 | import java.text.SimpleDateFormat |
27 | | -import java.util.Date |
| 27 | +import java.util.Locale |
| 28 | +import scala.util.Try |
28 | 29 |
|
29 | 30 |
|
30 | 31 |
|
/** Mix-in that maps configured "infinity" symbols in a column to concrete replacement values.
  *
  * A column may carry sentinel symbols for minus/plus infinity (e.g. "-inf"/"inf").
  * `replaceInfinitySymbols` rewrites rows equal to such a symbol into the configured
  * replacement value, leaving all other rows untouched. For date/timestamp targets the
  * replacement value is parsed eagerly on the driver (so a bad configuration fails fast
  * with `IllegalArgumentException`) and injected as a typed literal.
  */
trait InfinitySupport {
  /** Symbol representing minus infinity in the source data, if configured. */
  protected def infMinusSymbol: Option[String]

  /** Replacement value for the minus-infinity symbol, if configured. */
  protected def infMinusValue: Option[String]

  /** Symbol representing plus infinity in the source data, if configured. */
  protected def infPlusSymbol: Option[String]

  /** Replacement value for the plus-infinity symbol, if configured. */
  protected def infPlusValue: Option[String]

  /** Optional date/time pattern for parsing `infMinusValue` (falls back to an ISO-style default). */
  protected def infMinusPattern: Option[String]

  /** Optional date/time pattern for parsing `infPlusValue` (falls back to an ISO-style default). */
  protected def infPlusPattern: Option[String]

  /** Type the source column is read as (symbols are compared after casting to this type). */
  protected val origType: DataType

  /** Type the column is being standardized to; drives how replacement values are parsed. */
  protected val targetType: DataType

  /** Replaces the configured minus/plus infinity symbols in `column` with their
    * replacement values; rows matching neither symbol pass through unchanged.
    *
    * @param column the input column (typed as `origType`)
    * @return the column with infinity symbols substituted
    * @throws IllegalArgumentException if a configured date/timestamp replacement value
    *                                  cannot be parsed with its (or the default) pattern
    */
  def replaceInfinitySymbols(column: Column): Column = {
    // Build the typed replacement literal for each bound, according to the target type.
    val (minusReplacement, plusReplacement) = targetType match {
      case DateType =>
        val defaultDatePattern = "yyyy-MM-dd"
        // lit(java.sql.Date) yields a DateType literal directly — do NOT go through a
        // numeric cast: casting epoch millis to TimestampType treats the number as seconds.
        (infMinusValue.map(v => lit(parseInfinityValue(v, infMinusPattern.getOrElse(defaultDatePattern)))),
         infPlusValue.map(v => lit(parseInfinityValue(v, infPlusPattern.getOrElse(defaultDatePattern)))))
      case TimestampType =>
        val defaultTimestampPattern = "yyyy-MM-dd HH:mm:ss"
        // lit(java.sql.Timestamp) yields a TimestampType literal with millisecond precision.
        (infMinusValue.map(v =>
           lit(new Timestamp(parseInfinityValue(v, infMinusPattern.getOrElse(defaultTimestampPattern)).getTime))),
         infPlusValue.map(v =>
           lit(new Timestamp(parseInfinityValue(v, infPlusPattern.getOrElse(defaultTimestampPattern)).getTime))))
      case _ =>
        // Non-temporal targets: replacement is the configured string cast to the source type.
        (infMinusValue.map(v => lit(v).cast(origType)),
         infPlusValue.map(v => lit(v).cast(origType)))
    }

    // Apply minus first, then plus on the already-substituted column (same chaining as before).
    val withMinus = substituteSymbol(column, infMinusSymbol, minusReplacement)
    substituteSymbol(withMinus, infPlusSymbol, plusReplacement)
  }

  /** Rewrites rows equal to `symbol` (cast to `origType`) into `replacement`.
    * The `.otherwise(column)` is essential: a bare `when` would null out all
    * non-matching rows. No-op when either the symbol or the replacement is absent.
    */
  private def substituteSymbol(column: Column, symbol: Option[String], replacement: Option[Column]): Column =
    (symbol, replacement) match {
      case (Some(s), Some(r)) => when(column === lit(s).cast(origType), r).otherwise(column)
      case _                  => column
    }

  /** Parses a configured infinity replacement value with the given pattern.
    *
    * Parsing is strict (`setLenient(false)`) and uses `Locale.US` so pattern text
    * is interpreted deterministically regardless of the JVM default locale.
    *
    * @throws IllegalArgumentException if `value` does not match `pattern`
    */
  private def parseInfinityValue(value: String, pattern: String): Date = {
    val dateFormat = new SimpleDateFormat(pattern, Locale.US)
    dateFormat.setLenient(false)
    val parsed = Try(dateFormat.parse(value)).getOrElse(
      throw new IllegalArgumentException(s"Invalid infinity value: '$value' for pattern '$pattern'")
    )
    new Date(parsed.getTime)
  }
}
| 115 | + |
| 116 | + |
0 commit comments