Skip to content

Commit 71dc69a

Browse files
Infinity value sanitized, input pattern validation and ISO fallback
1 parent a13a54d commit 71dc69a

File tree

2 files changed

+334
-13
lines changed

2 files changed

+334
-13
lines changed

src/main/scala/za/co/absa/standardization/stages/InfinitySupport.scala

Lines changed: 121 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,29 +16,137 @@
1616

1717
package za.co.absa.standardization.stages
1818

19-
import org.apache.spark.sql.Column
20-
import org.apache.spark.sql.functions.{lit, when}
21-
import org.apache.spark.sql.types.DataType
19+
import org.apache.spark.sql.{Column, SparkSession, Row, functions => F}
20+
import org.apache.spark.sql.types.{DataType, DateType, StringType, TimestampType, StructType, StructField}
21+
import za.co.absa.standardization.types.{TypeDefaults, TypedStructField}
22+
import za.co.absa.standardization.types.parsers.DateTimeParser
23+
import za.co.absa.standardization.time.DateTimePattern
24+
import za.co.absa.standardization.types.TypedStructField.DateTimeTypeStructField
25+
import java.sql.Timestamp
26+
import scala.collection.JavaConverters._
27+
import java.text.SimpleDateFormat
28+
import java.util.Date
29+
30+
2231

2332
/**
 * Mix-in that substitutes configured "infinity" symbols in a column with configured replacement values
 * and converts the result back to the column's original type.
 *
 * For `DateType` and `TimestampType` columns the replacement values are validated up front: a value must
 * be parseable either with the field's own date/time pattern or with an ISO fallback pattern
 * (`yyyy-MM-dd` for dates, `yyyy-MM-dd'T'HH:mm:ss.SSSSSS` for timestamps). A value that only matches
 * the ISO fallback is re-rendered using the field's pattern when one is available.
 */
trait InfinitySupport {
  protected def infMinusSymbol: Option[String]
  protected def infMinusValue: Option[String]
  protected def infPlusSymbol: Option[String]
  protected def infPlusValue: Option[String]
  protected val origType: DataType
  protected def defaults: TypeDefaults
  protected def field: TypedStructField
  protected def spark: SparkSession

  // Declared as `def`s (not `val`s) so they do not take part in trait initialisation order.
  private def isoTimestampPattern: String = "yyyy-MM-dd'T'HH:mm:ss.SSSSSS"
  private def isoDatePattern: String = "yyyy-MM-dd"

  /**
   * Accepts only values made of ASCII letters, digits, space, colon, dot or hyphen.
   *
   * The space is permitted because the default timestamp pattern used by [[replaceInfinitySymbols]]
   * (`yyyy-MM-dd HH:mm:ss`) itself contains one; without it such values could never be configured.
   *
   * @param s the raw configured value
   * @return `s` unchanged when it passes the check
   * @throws IllegalArgumentException when `s` is empty or contains any other character
   */
  private def sanitizeInput(s: String): String = {
    if (s.matches("[a-zA-Z0-9:. -]+")) {
      s
    } else {
      throw new IllegalArgumentException(
        s"Invalid input '$s': must be alphanumeric, space, colon, dot or hyphen")
    }
  }

  /**
   * Looks up the date/time pattern configured on `field`.
   *
   * @param dataType the type the pattern is requested for
   * @return the pattern string when `dataType` is a date or timestamp, `field` is a
   *         `DateTimeTypeStructField` and its pattern resolved successfully to a value; `None` otherwise
   */
  private def getPattern(dataType: DataType): Option[String] = {
    dataType match {
      case DateType | TimestampType =>
        field match {
          case dateField: DateTimeTypeStructField[_] =>
            dateField.pattern.toOption.flatten.map(_.pattern)
          case _ => None
        }
      case _ => None
    }
  }

  /**
   * Checks that `value` can be parsed as `dataType` and returns the string form to use as replacement.
   *
   * The value is first tried against `patternOpt`; on success the sanitized value is returned as is.
   * Otherwise it is tried against the ISO fallback pattern and, on success, re-rendered with
   * `patternOpt` (or with the ISO pattern when no pattern was given). Parsing and re-rendering are both
   * done by Spark (`to_timestamp`/`to_date`/`date_format`) so that the produced string is one that the
   * same Spark session can parse again in [[replaceInfinitySymbols]].
   *
   * @param value      the configured infinity replacement value
   * @param dataType   the target column type
   * @param patternOpt the field's date/time pattern, if any
   * @return the replacement value as a string
   * @throws IllegalArgumentException when the value fails sanitisation or matches neither pattern
   */
  private def validateAndConvertInfinityValue(value: String, dataType: DataType, patternOpt: Option[String]): String = {
    val sanitizedValue = sanitizeInput(value)
    val schema = StructType(Seq(StructField("value", StringType, nullable = false)))
    val df = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row(sanitizedValue))), schema)

    // Expression parsing the single-row "value" column with the given pattern.
    def parsedWith(pattern: String): Column = dataType match {
      case TimestampType => F.to_timestamp(F.col("value"), pattern)
      case DateType => F.to_date(F.col("value"), pattern)
      case _ => F.col("value").cast(dataType)
    }

    // Evaluates the expression on the single row; a SQL NULL result becomes None.
    def evaluate(expression: Column): Option[Any] =
      Option(df.select(expression.alias("parsed")).first().get(0))

    val isoPatternOpt: Option[String] = dataType match {
      case TimestampType => Some(isoTimestampPattern)
      case DateType => Some(isoDatePattern)
      case _ => None
    }

    val matchesConfiguredPattern = patternOpt.exists(pattern => evaluate(parsedWith(pattern)).isDefined)

    if (matchesConfiguredPattern) {
      sanitizedValue
    } else {
      val reformatted = isoPatternOpt.flatMap { isoPattern =>
        val targetPattern = patternOpt.getOrElse(isoPattern)
        evaluate(F.date_format(parsedWith(isoPattern), targetPattern)).map(_.toString)
      }
      reformatted.getOrElse {
        throw new IllegalArgumentException(
          s"Invalid infinity value: '$value' for type: $dataType with pattern ${patternOpt.getOrElse("none")} " +
            s"and ISO fallback (${isoPatternOpt.getOrElse("")})")
      }
    }
  }

  /** Validates one configured replacement value according to `origType`. */
  private def validateInfinityValue(rawValue: Option[String]): Option[String] = {
    origType match {
      case DateType | TimestampType =>
        rawValue.map(validateAndConvertInfinityValue(_, origType, getPattern(origType)))
      case _ =>
        rawValue.map(sanitizeInput)
    }
  }

  // Lazy on purpose: these read abstract members (`origType`, `field`, `spark`, ...) that the
  // implementing class may not have initialised yet while this trait's own initialiser runs.
  // Validation therefore happens on first access instead of at construction time.
  protected lazy val validatedInfMinusValue: Option[String] = validateInfinityValue(infMinusValue)

  protected lazy val validatedInfPlusValue: Option[String] = validateInfinityValue(infPlusValue)

  /**
   * Replaces the configured infinity symbols in `column` and converts the result to `origType`.
   *
   * The column is compared as a string. The minus replacement is applied first, then the plus
   * replacement; a replacement is applied only when both its symbol and its value are defined.
   * Date and timestamp columns are then parsed with the field's pattern (or a default one), falling
   * back to the ISO pattern; every other type is produced with a plain cast.
   *
   * @param column the input column
   * @return a column of type `origType` with infinity symbols replaced
   */
  def replaceInfinitySymbols(column: Column): Column = {
    val replacements: Seq[(String, String)] = Seq(
      infMinusSymbol -> validatedInfMinusValue,
      infPlusSymbol -> validatedInfPlusValue
    ).collect { case (Some(symbol), Some(value)) => symbol -> value }

    val resultCol = replacements.foldLeft(column.cast(StringType)) { case (current, (symbol, value)) =>
      F.when(current === F.lit(symbol), F.lit(value)).otherwise(current)
    }

    origType match {
      case TimestampType =>
        val pattern = getPattern(origType).getOrElse(
          defaults.defaultTimestampTimeZone.map(_ => isoTimestampPattern).getOrElse("yyyy-MM-dd HH:mm:ss")
        )
        F.coalesce(
          F.to_timestamp(resultCol, pattern),
          F.to_timestamp(resultCol, isoTimestampPattern)
        ).cast(origType)
      case DateType =>
        val pattern = getPattern(origType).getOrElse(isoDatePattern)
        F.coalesce(
          F.to_date(resultCol, pattern),
          F.to_date(resultCol, isoDatePattern)
        ).cast(origType)
      case _ =>
        resultCol.cast(origType)
    }
  }
}
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
package za.co.absa.standardization
2+
3+
4+
import org.scalatest.BeforeAndAfterAll
5+
import org.scalatest.funsuite.AnyFunSuite
6+
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
7+
import org.apache.spark.sql.types.{DataType, DateType, Metadata, StringType, StructField, StructType, TimestampType}
8+
import org.apache.spark.sql.functions.{col, lit, to_date, to_timestamp, when}
9+
import java.sql
10+
import java.sql.{Date, Timestamp}
11+
import java.text.{SimpleDateFormat,ParseException}
12+
import java.util.TimeZone
13+
import scala.util.Try
14+
15+
16+
/**
 * Tests for infinity-symbol replacement on date and timestamp columns with pattern validation and
 * ISO fallback.
 *
 * NOTE(review): these tests exercise the private `replaceInfinitySymbols` helper defined in this suite,
 * not a production implementation - confirm whether that is intended.
 */
class InfinitySupportIsoTest extends AnyFunSuite with BeforeAndAfterAll {
  import org.apache.spark.sql.functions.coalesce

  var sparkSession: SparkSession = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sparkSession = SparkSession.builder()
      .appName("InfinityISOTest")
      .master("local[*]")
      .config("spark.sql.legacy.parquet.datetimeRebaseModeInRead", "CORRECTED")
      .config("spark.sql.legacy.parquet.datetimeRebaseModeInWrite", "CORRECTED")
      .getOrCreate()
  }

  override def afterAll(): Unit = {
    try {
      if (sparkSession != null) {
        sparkSession.stop()
      }
    } finally {
      super.afterAll()
    }
  }

  /** Builds a single-column ("value") string DataFrame; the column is nullable because tests feed nulls. */
  private def createTestDataFrame(data: Seq[String]): DataFrame = {
    sparkSession.createDataFrame(
      sparkSession.sparkContext.parallelize(data.map(Row(_))),
      StructType(Seq(StructField("value", StringType, nullable = true)))
    )
  }

  // Sample configuration; not referenced by any test in this suite.
  private val configString =
    """
     standardization.infinity {
       minus.symbol = "-inf"
       minus.value = "1970-01-01 00:00:00.000000"
       plus.symbol = "inf"
       plus.value ="9999-12-31 23:59:59.999999"
     }
    """

  /**
   * Validates the two replacement values and builds a column in which the infinity symbols are
   * replaced by them and every other value is parsed as `dataType`.
   *
   * @param column      the string column to transform
   * @param dataType    `TimestampType` or `DateType`
   * @param pattern     optional date/time pattern, tried before the ISO fallback
   * @param timezone    time zone id used while validating the replacement values
   * @param minusSymbol symbol denoting negative infinity
   * @param minusValue  replacement for `minusSymbol`
   * @param plusSymbol  symbol denoting positive infinity
   * @param plusValue   replacement for `plusSymbol`
   * @return a column of type `dataType`
   * @throws IllegalArgumentException for an unsupported `dataType`, or when a replacement value matches
   *                                  neither `pattern` nor the ISO fallback
   */
  private def replaceInfinitySymbols(column: Column, dataType: DataType, pattern: Option[String], timezone: String, minusSymbol: String, minusValue: String, plusSymbol: String, plusValue: String): Column = {
    val isoPattern = dataType match {
      case TimestampType => "yyyy-MM-dd'T'HH:mm:ss.SSSSSS"
      case DateType => "yyyy-MM-dd"
      case _ => throw new IllegalArgumentException(s"Unsupported data type: $dataType")
    }
    // The configured pattern (if any) takes precedence over the ISO fallback.
    val formatsToTry = pattern.toSeq :+ isoPattern

    // Strictly parses `value` with `fmt`; Left carries the parse failure.
    def tryParse(value: String, fmt: String): Either[ParseException, Unit] = {
      try {
        val sdf = new SimpleDateFormat(fmt)
        sdf.setTimeZone(TimeZone.getTimeZone(timezone))
        sdf.setLenient(false)
        sdf.parse(value)
        Right(())
      } catch {
        case e: ParseException => Left(e)
      }
    }

    def validateValue(value: String): Unit = {
      val failures = formatsToTry.map(tryParse(value, _)).collect { case Left(e) => e }
      if (failures.size == formatsToTry.size) {
        val errorMsg = s"Invalid infinity value: '$value' for type: ${dataType.toString.toLowerCase} with pattern ${pattern.getOrElse("none")} and ISO fallback ($isoPattern)"
        throw new IllegalArgumentException(errorMsg, failures.lastOption.orNull)
      }
    }

    validateValue(minusValue)
    validateValue(plusValue)

    // Parse the replacement into the target type so that every branch of the `when` below has the
    // same type; mixing a string literal with a timestamp/date branch would coerce the result to string.
    def typedReplacement(value: String): Column = {
      val candidates = formatsToTry.map { fmt =>
        dataType match {
          case TimestampType => to_timestamp(lit(value), fmt)
          case _ => to_date(lit(value), fmt)
        }
      }
      coalesce(candidates: _*)
    }

    val parsedColumn = dataType match {
      case TimestampType => pattern.map(p => to_timestamp(column, p)).getOrElse(to_timestamp(column))
      case _ => pattern.map(p => to_date(column, p)).getOrElse(to_date(column))
    }

    when(column === minusSymbol, typedReplacement(minusValue))
      .when(column === plusSymbol, typedReplacement(plusValue))
      .otherwise(parsedColumn)
  }

  test("Replace infinity symbols for timestamp with valid pattern") {
    val df = createTestDataFrame(Seq("-inf", "inf", "2025-07-05 12:34:56", null))
    val result = df.withColumn("result", replaceInfinitySymbols(col("value"), TimestampType, Some("yyyy-MM-dd HH:mm:ss"), "UTC", "-inf", "1970-01-01 00:00:00", "inf", "9999-12-31 23:59:59"))
      .select("result")
      .collect()
      .map(_.getAs[Timestamp](0))

    val expected = Seq(
      Timestamp.valueOf("1970-01-01 00:00:00"),
      Timestamp.valueOf("9999-12-31 23:59:59"),
      Timestamp.valueOf("2025-07-05 12:34:56"),
      null
    )

    assert(result sameElements expected)
  }

  // NOTE(review): the inputs here already match the configured pattern, so the ISO conversion the
  // test name refers to is not exercised - consider supplying ISO-formatted values.
  test("Convert invalid timestamp pattern to ISO") {
    val df = createTestDataFrame(Seq("-inf", "inf"))
    val result = df.withColumn("result", replaceInfinitySymbols(
      col("value"),
      TimestampType,
      Some("yyyy-MM-dd HH:mm:ss"),
      "UTC",
      "-inf",
      "1970-01-01 00:00:00",
      "inf",
      "9999-12-31 23:59:59"))
      .select("result")
      .collect()
      .map(_.getAs[Timestamp](0))

    val expected = Seq(
      Timestamp.valueOf("1970-01-01 00:00:00"),
      Timestamp.valueOf("9999-12-31 23:59:59")
    )

    assert(result sameElements expected)
  }

  test("Replace infinity symbol for date with valid pattern") {
    val df = createTestDataFrame(Seq("-inf", "inf", "2025-07-05", null))
    val result = df.withColumn("result", replaceInfinitySymbols(
      col("value"),
      DateType,
      Some("yyyy-MM-dd"),
      "UTC",
      "-inf",
      "1970-01-01",
      "inf",
      "9999-12-31"
    ))
      .select("result")
      .collect()
      .map(_.getAs[Date](0))

    val expected = Seq(
      Date.valueOf("1970-01-01"),
      Date.valueOf("9999-12-31"),
      Date.valueOf("2025-07-05"),
      null
    )

    assert(result sameElements expected)
  }

  test("Throw error for unparseable infinity value") {
    val exception = intercept[IllegalArgumentException] {
      replaceInfinitySymbols(
        col("value"),
        TimestampType,
        Some("yyyy-MM-dd HH:mm:ss"),
        "UTC",
        "-inf",
        "invalid_date",
        "inf",
        "9999-12-31 23:59:59"
      )
    }

    assert(exception.getMessage.contains("Invalid infinity value: 'invalid_date' for type: timestamp"))
    assert(exception.getMessage.contains("pattern yyyy-MM-dd HH:mm:ss"))
    assert(exception.getMessage.contains("ISO fallback (yyyy-MM-dd'T'HH:mm:ss.SSSSSS"))
  }

  test("Handle missing pattern with ISO fallback") {
    val df = createTestDataFrame(Seq("-inf", "inf"))
    val result = df.withColumn("result", replaceInfinitySymbols(
      col("value"),
      DateType,
      None,
      "UTC",
      "-inf",
      "1970-01-01",
      "inf",
      "9999-12-31"
    ))
      .select("result")
      .collect()
      .map(_.getAs[Date](0))

    val expected = Seq(
      Date.valueOf("1970-01-01"),
      Date.valueOf("9999-12-31")
    )

    assert(result sameElements expected)
  }
}

0 commit comments

Comments
 (0)