|
| 1 | +package org.locationtech.rasterframes.encoders |
| 2 | + |
| 3 | +import frameless.{RecordEncoderField, TypedEncoder} |
| 4 | +import org.apache.spark.sql.FramelessInternals |
| 5 | +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, InvokeLike, NewInstance, StaticInvoke} |
| 6 | +import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal} |
| 7 | +import org.apache.spark.sql.types.{DataType, Metadata, StructField, StructType} |
| 8 | + |
| 9 | +import scala.reflect.{ClassTag, classTag} |
| 10 | + |
| 11 | +/** Can be useful for non Scala types and for complicated case classes with implicits in the constructor. */ |
| 12 | +object ManualTypedEncoder { |
| 13 | + /** Invokes apply from the companion object. */ |
| 14 | + def staticInvoke[T: ClassTag]( |
| 15 | + fields: List[RecordEncoderField], |
| 16 | + fieldNameModify: String => String = identity, |
| 17 | + isNullable: Boolean = true |
| 18 | + ): TypedEncoder[T] = apply[T](fields, { (classTag, newArgs, jvmRepr) => StaticInvoke(classTag.runtimeClass, jvmRepr, "apply", newArgs, propagateNull = true, returnNullable = false) }, fieldNameModify, isNullable) |
| 19 | + |
| 20 | + /** Invokes object constructor. */ |
| 21 | + def newInstance[T: ClassTag]( |
| 22 | + fields: List[RecordEncoderField], |
| 23 | + fieldNameModify: String => String = identity, |
| 24 | + isNullable: Boolean = true |
| 25 | + ): TypedEncoder[T] = apply[T](fields, { (classTag, newArgs, jvmRepr) => NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) }, fieldNameModify, isNullable) |
| 26 | + |
| 27 | + def apply[T: ClassTag]( |
| 28 | + fields: List[RecordEncoderField], |
| 29 | + newInstanceExpression: (ClassTag[T], Seq[Expression], DataType) => InvokeLike, |
| 30 | + fieldNameModify: String => String = identity, |
| 31 | + isNullable: Boolean = true |
| 32 | + ): TypedEncoder[T] = make[T](fields, newInstanceExpression, fieldNameModify, isNullable, classTag[T]) |
| 33 | + |
| 34 | + private def make[T]( |
| 35 | + // the catalyst struct |
| 36 | + fields: List[RecordEncoderField], |
| 37 | + // newInstanceExpression for the fromCatalyst function |
| 38 | + newInstanceExpression: (ClassTag[T], Seq[Expression], DataType) => InvokeLike, |
| 39 | + // allows to convert the field name into the field name getter |
| 40 | + fieldNameModify: String => String, |
| 41 | + // is the codec nullable |
| 42 | + isNullable: Boolean, |
| 43 | + // ClassTag is required for the TypedEncoder constructor |
| 44 | + // it is passed explicitly to disambiguate ClassTag passed implicitly as a function argument |
| 45 | + // and the one from the TypedEncoder constructor |
| 46 | + ct: ClassTag[T] |
| 47 | + ): TypedEncoder[T] = new TypedEncoder[T]()(ct) { |
| 48 | + def nullable: Boolean = isNullable |
| 49 | + |
| 50 | + def jvmRepr: DataType = FramelessInternals.objectTypeFor[T] |
| 51 | + |
| 52 | + def catalystRepr: DataType = { |
| 53 | + val structFields = fields.map { field => |
| 54 | + StructField( |
| 55 | + name = field.name, |
| 56 | + dataType = field.encoder.catalystRepr, |
| 57 | + nullable = field.encoder.nullable, |
| 58 | + metadata = Metadata.empty |
| 59 | + ) |
| 60 | + } |
| 61 | + |
| 62 | + StructType(structFields) |
| 63 | + } |
| 64 | + |
| 65 | + def fromCatalyst(path: Expression): Expression = { |
| 66 | + val newArgs: Seq[Expression] = fields.map { field => |
| 67 | + field.encoder.fromCatalyst( GetStructField(path, field.ordinal, Some(field.name)) ) |
| 68 | + } |
| 69 | + val newExpr = newInstanceExpression(classTag, newArgs, jvmRepr) |
| 70 | + |
| 71 | + val nullExpr = Literal.create(null, jvmRepr) |
| 72 | + If(IsNull(path), nullExpr, newExpr) |
| 73 | + } |
| 74 | + |
| 75 | + def toCatalyst(path: Expression): Expression = { |
| 76 | + val nameExprs = fields.map { field => Literal(field.name) } |
| 77 | + |
| 78 | + val valueExprs: Seq[Expression] = fields.map { field => |
| 79 | + val fieldPath = Invoke(path, fieldNameModify(field.name), field.encoder.jvmRepr, Nil) |
| 80 | + field.encoder.toCatalyst(fieldPath) |
| 81 | + } |
| 82 | + |
| 83 | + // the way exprs are encoded in CreateNamedStruct |
| 84 | + val exprs = nameExprs.zip(valueExprs).flatMap { case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil } |
| 85 | + |
| 86 | + val createExpr = CreateNamedStruct(exprs) |
| 87 | + val nullExpr = Literal.create(null, createExpr.dataType) |
| 88 | + If(IsNull(path), nullExpr, createExpr) |
| 89 | + } |
| 90 | + } |
| 91 | +} |
0 commit comments