Skip to content

Commit eefe87a

Browse files
authored
#38: ArrayTransformation might fail if column name is a number (#39)
* Upgrade to spark-commons 0.4.0
* The method `ArrayTransformations.arrCol` was replaced by _Spark Commons_' `col_of_path` function.
1 parent 38da9e2 commit eefe87a

File tree

3 files changed

+7
-17
lines changed

3 files changed

+7
-17
lines changed

project/Dependencies.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ object Dependencies {
3333
List(
3434
"org.apache.spark" %% "spark-core" % sparkVersion % Provided,
3535
"org.apache.spark" %% "spark-sql" % sparkVersion % Provided,
36-
"za.co.absa" %% s"spark-commons-spark$sparkVersionUpToMinor" % "0.3.2" % Provided,
37-
"za.co.absa" %% "spark-commons-test" % "0.3.2" % Test,
36+
"za.co.absa" %% s"spark-commons-spark$sparkVersionUpToMinor" % "0.4.0" % Provided,
37+
"za.co.absa" %% "spark-commons-test" % "0.4.0" % Test,
3838
"com.typesafe" % "config" % "1.4.1",
3939
"com.github.mrpowers" %% "spark-fast-tests" % sparkFastTestsVersion(scalaVersion) % Test,
4040
"org.scalatest" %% "scalatest" % "3.2.2" % Test

src/main/scala/za/co/absa/standardization/ArrayTransformations.scala

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.apache.spark.sql.{Column, Dataset, Row, SparkSession}
2323
import org.slf4j.LoggerFactory
2424
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements
2525
import za.co.absa.spark.commons.utils.SchemaUtils
26+
import za.co.absa.spark.commons.sql.functions.col_of_path
2627

2728
object ArrayTransformations {
2829
private val logger = LoggerFactory.getLogger(this.getClass)
@@ -61,7 +62,7 @@ object ArrayTransformations {
6162
column as tokens.head
6263
} // some other attribute
6364
else if (!columnName.startsWith(currPath)) {
64-
arrCol(currPath)
65+
col_of_path(currPath)
6566
} // partial match, keep going
6667
else if (topType.isEmpty) {
6768
struct(helper(tokens.tail, pathAcc ++ List(tokens.head))) as tokens.head
@@ -76,23 +77,12 @@ object ArrayTransformations {
7677
}
7778
struct(fields.map(field => helper((List(field) ++ tokens.tail).distinct, pathAcc :+ tokens.head) as field): _*) as tokens.head
7879
case _: ArrayType => throw new IllegalStateException("Cannot reconstruct array columns. Please use this within arrayTransform.")
79-
case _: DataType => arrCol(currPath) as tokens.head
80+
case _: DataType => col_of_path(currPath) as tokens.head
8081
}
8182
}
8283
}
8384

8485
ds.withColumn(toks.head, helper(toks, Seq()))
8586
}
8687

87-
def arrCol(any: String): Column = {
88-
val toks = any.replaceAll("\\[(\\d+)\\]", "\\.$1").split("\\.")
89-
toks.tail.foldLeft(col(toks.head)){
90-
case (acc, tok) =>
91-
if (tok.matches("\\d+")) {
92-
acc(tok.toInt)
93-
} else {
94-
acc(tok)
95-
}
96-
}
97-
}
9888
}

src/test/scala/za/co/absa/standardization/ArrayTransformationsSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ class ArrayTransformationsSuite extends AnyFunSuite with SparkTestBase {
7979

8080
val res = ArrayTransformations.flattenArrays(df, "a")
8181

82-
val exp = List(
82+
val exp = Seq(
8383
Nested1Level(List(Some(1), None, Some(2), Some(3), Some(4), Some(5), Some(6))),
8484
Nested1Level(List()),
85-
Nested1Level(null)).toSeq
85+
Nested1Level(null))
8686

8787
val resLocal = res.as[Nested1Level].collect().toSeq
8888

0 commit comments

Comments
 (0)