This repository was archived by the owner on Aug 31, 2021. It is now read-only.

Commit 483f19b (parent: 03217f7)

Fixed issue #76. See PR #87

3 files changed: 40 additions, 34 deletions

src/main/scala/com/audienceproject/spark/dynamodb/implicits.scala (14 additions, 14 deletions)
@@ -24,7 +24,6 @@ import com.audienceproject.spark.dynamodb.reflect.SchemaAnalysis
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types.StructField
 
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
@@ -41,28 +40,29 @@ object implicits {
 
         def dynamodbAs[T <: Product : ClassTag : TypeTag](tableName: String): Dataset[T] = {
             implicit val encoder: Encoder[T] = ExpressionEncoder()
-            getColumnsAlias(getDynamoDBSource(tableName)
-                .schema(SchemaAnalysis[T]).load()).as
+            val (schema, aliasMap) = SchemaAnalysis[T]
+            getColumnsAlias(getDynamoDBSource(tableName).schema(schema).load(), aliasMap).as
         }
 
         def dynamodbAs[T <: Product : ClassTag : TypeTag](tableName: String, indexName: String): Dataset[T] = {
             implicit val encoder: Encoder[T] = ExpressionEncoder()
-            getColumnsAlias(getDynamoDBSource(tableName)
-                .option("indexName", indexName)
-                .schema(SchemaAnalysis[T]).load()).as
+            val (schema, aliasMap) = SchemaAnalysis[T]
+            getColumnsAlias(
+                getDynamoDBSource(tableName).option("indexName", indexName).schema(schema).load(), aliasMap).as
         }
 
         private def getDynamoDBSource(tableName: String): DataFrameReader =
             reader.format("com.audienceproject.spark.dynamodb.datasource").option("tableName", tableName)
 
-        private def getColumnsAlias(dataFrame: DataFrame): DataFrame = {
-            val columnsAlias = dataFrame.schema.collect({
-                case StructField(name, _, _, metadata) if metadata.contains("alias") =>
-                    col(name).as(metadata.getString("alias"))
-                case StructField(name, _, _, _) =>
-                    col(name)
-            })
-            dataFrame.select(columnsAlias: _*)
+        private def getColumnsAlias(dataFrame: DataFrame, aliasMap: Map[String, String]): DataFrame = {
+            if (aliasMap.isEmpty) dataFrame
+            else {
+                val columnsAlias = dataFrame.columns.map({
+                    case name if aliasMap.isDefinedAt(name) => col(name) as aliasMap(name)
+                    case name => col(name)
+                })
+                dataFrame.select(columnsAlias: _*)
+            }
         }
 
     }
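The net effect of this change: the column aliases that were previously carried through StructField metadata now travel in an explicit Map[String, String] from SchemaAnalysis to getColumnsAlias. A minimal sketch of the flow, using a hypothetical case class and placeholder table name (neither is part of this commit):

import com.audienceproject.spark.dynamodb.attribute
import com.audienceproject.spark.dynamodb.implicits._
import org.apache.spark.sql.SparkSession

// Hypothetical case class: the DynamoDB item stores the key under "pk",
// while the Scala field is named primaryKey.
case class Fruit(@attribute("pk") primaryKey: String, color: String)

object AliasFlowSketch {
    def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder.appName("sketch").master("local[*]").getOrCreate()

        // With this commit, SchemaAnalysis[Fruit] yields two values:
        //   schema   = StructType with fields "pk" and "color"
        //   aliasMap = Map("pk" -> "primaryKey")
        // getColumnsAlias then selects col("pk").as("primaryKey"), so the frame's
        // columns line up with the case class and the ExpressionEncoder succeeds.
        val fruit = spark.read.dynamodbAs[Fruit]("SomeTable") // table name is a placeholder
        fruit.printSchema()
    }
}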

src/main/scala/com/audienceproject/spark/dynamodb/reflect/SchemaAnalysis.scala (22 additions, 20 deletions)
@@ -22,7 +22,7 @@ package com.audienceproject.spark.dynamodb.reflect
 
 import com.audienceproject.spark.dynamodb.attribute
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField, StructType}
+import org.apache.spark.sql.types.{StructField, StructType}
 
 import scala.reflect.ClassTag
 import scala.reflect.runtime.{universe => ru}
@@ -32,35 +32,37 @@ import scala.reflect.runtime.{universe => ru}
   */
 private[dynamodb] object SchemaAnalysis {
 
-    def apply[T <: Product : ClassTag : ru.TypeTag]: StructType = {
+    def apply[T <: Product : ClassTag : ru.TypeTag]: (StructType, Map[String, String]) = {
 
         val runtimeMirror = ru.runtimeMirror(getClass.getClassLoader)
 
         val classObj = scala.reflect.classTag[T].runtimeClass
         val classSymbol = runtimeMirror.classSymbol(classObj)
 
-        val sparkFields = classSymbol.primaryConstructor.typeSignature.paramLists.head.map(field => {
-            val sparkType = ScalaReflection.schemaFor(field.typeSignature).dataType
+        val params = classSymbol.primaryConstructor.typeSignature.paramLists.head
+        val (sparkFields, aliasMap) = params.foldLeft((List.empty[StructField], Map.empty[String, String]))({
+            case ((list, map), field) =>
+                val sparkType = ScalaReflection.schemaFor(field.typeSignature).dataType
 
-            // Black magic from here:
-            // https://stackoverflow.com/questions/23046958/accessing-an-annotation-value-in-scala
-            val attrName = field.annotations.collectFirst({
-                case ann: ru.AnnotationApi if ann.tree.tpe =:= ru.typeOf[attribute] =>
-                    ann.tree.children.tail.collectFirst({
-                        case ru.Literal(ru.Constant(name: String)) => name
-                    })
-            }).flatten
+                // Black magic from here:
+                // https://stackoverflow.com/questions/23046958/accessing-an-annotation-value-in-scala
+                val attrName = field.annotations.collectFirst({
+                    case ann: ru.AnnotationApi if ann.tree.tpe =:= ru.typeOf[attribute] =>
+                        ann.tree.children.tail.collectFirst({
+                            case ru.Literal(ru.Constant(name: String)) => name
+                        })
+                }).flatten
 
-            if (attrName.isDefined) {
-                val metadata = new MetadataBuilder().putString("alias", field.name.toString).build()
-                StructField(attrName.get, sparkType, nullable = true, metadata)
-            } else {
-                StructField(field.name.toString, sparkType, nullable = true, Metadata.empty)
-            }
+                if (attrName.isDefined) {
+                    val sparkField = StructField(attrName.get, sparkType, nullable = true)
+                    (list :+ sparkField, map + (attrName.get -> field.name.toString))
+                } else {
+                    val sparkField = StructField(field.name.toString, sparkType, nullable = true)
+                    (list :+ sparkField, map)
+                }
         })
 
-        StructType(sparkFields)
-
+        (StructType(sparkFields), aliasMap)
     }
 
 }
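The "black magic" annotation lookup survives the refactoring unchanged. Here is a self-contained sketch of that technique, with a stand-in annotation (attr, Item, and AnnotationLookupSketch are hypothetical names; the project's real annotation is attribute):

import scala.annotation.StaticAnnotation
import scala.reflect.runtime.{universe => ru}

// Stand-in for the project's `attribute` annotation.
class attr(name: String) extends StaticAnnotation

case class Item(@attr("pk") id: String, color: String)

object AnnotationLookupSketch {
    def main(args: Array[String]): Unit = {
        val mirror = ru.runtimeMirror(getClass.getClassLoader)
        val classSymbol = mirror.classSymbol(classOf[Item])

        // Annotations on case-class fields are attached to the parameters of the
        // primary constructor, which is why SchemaAnalysis walks paramLists.head.
        val params = classSymbol.primaryConstructor.typeSignature.paramLists.head

        params.foreach { field =>
            // Match the annotation by type, then pull the String literal out of
            // the annotation tree's argument list.
            val attrName = field.annotations.collectFirst({
                case ann: ru.AnnotationApi if ann.tree.tpe =:= ru.typeOf[attr] =>
                    ann.tree.children.tail.collectFirst({
                        case ru.Literal(ru.Constant(name: String)) => name
                    })
            }).flatten
            println(s"${field.name} -> ${attrName.getOrElse(field.name.toString)}")
        }
        // Expected output:
        //   id -> pk
        //   color -> color
    }
}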

src/test/scala/com/audienceproject/spark/dynamodb/DefaultSourceTest.scala (4 additions, 0 deletions)
@@ -57,6 +57,10 @@ class DefaultSourceTest extends AbstractInMemoryTest {
 
     test("Test of attribute name alias") {
         import spark.implicits._
+        spark.read.dynamodb("TestFruit").printSchema()
+        spark.read.dynamodb("TestFruit").show()
+        spark.read.dynamodbAs[TestFruit]("TestFruit").printSchema()
+        spark.read.dynamodbAs[TestFruit]("TestFruit").show()
         val itemApple = spark.read.dynamodbAs[TestFruit]("TestFruit")
             .filter($"primaryKey" === "apple")
             .takeAsList(1).get(0)
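For this alias test to be meaningful, TestFruit must map at least one Scala field to a differently named DynamoDB attribute. A sketch of what such a definition could look like (the actual case class lives in the test sources and may differ):

import com.audienceproject.spark.dynamodb.attribute

// Hypothetical sketch; the real TestFruit is defined in the project's test sources.
// The filter on $"primaryKey" in the test only works after getColumnsAlias has
// renamed the DynamoDB attribute back to the case-class field name.
case class TestFruit(@attribute("name") primaryKey: String,
                     @attribute("color") fruitColor: String,
                     weightKg: Double)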
