Skip to content

Commit 68fdd3c

Browse files
authored
Merge pull request #22 from VirtusLab/conditional-columns
Support conditional columns with `when`
2 parents 75fbcee + a6ff6a9 commit 68fdd3c

File tree

13 files changed

+318
-196
lines changed

13 files changed

+318
-196
lines changed

src/main/ColumnOp.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ object ColumnOp:
1616
given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Plus[T1, T2] with
1717
type Out = DataType.CommonNumericNonNullableType[T1, T2]
1818
given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Plus[T1, T2] with
19-
type Out = DataType.CommonNumericOptType[T1, T2]
19+
type Out = DataType.CommonNumericNullableType[T1, T2]
2020

2121
trait Minus[T1 <: DataType, T2 <: DataType]:
2222
type Out <: DataType
@@ -25,7 +25,7 @@ object ColumnOp:
2525
given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Minus[T1, T2] with
2626
type Out = DataType.CommonNumericNonNullableType[T1, T2]
2727
given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Minus[T1, T2] with
28-
type Out = DataType.CommonNumericOptType[T1, T2]
28+
type Out = DataType.CommonNumericNullableType[T1, T2]
2929

3030
trait Mult[T1 <: DataType, T2 <: DataType]:
3131
type Out <: DataType
@@ -34,7 +34,7 @@ object ColumnOp:
3434
given numericNonNullable[T1 <: NumericType, T2 <: NumericType]: Mult[T1, T2] with
3535
type Out = DataType.CommonNumericNonNullableType[T1, T2]
3636
given numericNullable[T1 <: NumericOptType, T2 <: NumericOptType]: Mult[T1, T2] with
37-
type Out = DataType.CommonNumericOptType[T1, T2]
37+
type Out = DataType.CommonNumericNullableType[T1, T2]
3838

3939
trait Div[T1 <: DataType, T2 <: DataType]:
4040
type Out <: DataType

src/main/DataFrame.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package org.virtuslab.iskra
33
import org.apache.spark.sql
44
import org.apache.spark.sql.SparkSession
55
import scala.quoted.*
6-
import types.{DataType, StructType}
6+
import types.{Encoder, StructEncoder}
77

88
class DataFrame[Schema](val untyped: UntypedDataFrame):
99
type Alias
@@ -47,9 +47,9 @@ object DataFrame:
4747

4848
// TODO: Use only a subset of columns
4949
private def collectAsImpl[FrameSchema : Type, A : Type](df: Expr[DataFrame[FrameSchema]])(using Quotes): Expr[List[A]] =
50-
Expr.summon[DataType.Encoder[A]] match
50+
Expr.summon[Encoder[A]] match
5151
case Some(encoder) => encoder match
52-
case '{ $enc: DataType.StructEncoder[A] { type StructSchema = structSchema } } =>
52+
case '{ $enc: StructEncoder[A] { type StructSchema = structSchema } } =>
5353
Type.of[MacroHelpers.AsTuple[FrameSchema]] match
5454
case '[`structSchema`] =>
5555
'{ ${ df }.untyped.collect.toList.map(row => ${ enc }.decode(row).asInstanceOf[A]) }
@@ -58,7 +58,7 @@ object DataFrame:
5858
val structColumns = allColumns(Type.of[structSchema])
5959
val errorMsg = s"A data frame with columns:\n${showColumns(frameColumns)}\ncannot be collected as a list of ${Type.show[A]}, which would be encoded as a row with columns:\n${showColumns(structColumns)}"
6060
quotes.reflect.report.errorAndAbort(errorMsg)
61-
case '{ $enc: DataType.Encoder[A] { type ColumnType = colType } } =>
61+
case '{ $enc: Encoder[A] { type ColumnType = colType } } =>
6262
def fromDataType[T : Type] =
6363
Type.of[T] match
6464
case '[`colType`] =>

src/main/DataFrameBuilders.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,19 @@ import scala.quoted._
44
import org.apache.spark.sql
55
import org.apache.spark.sql.SparkSession
66
import org.virtuslab.iskra.DataFrame
7-
import org.virtuslab.iskra.types.{DataType, StructType}
8-
import DataType.{Encoder, StructEncoder, PrimitiveEncoder}
7+
import org.virtuslab.iskra.types.{DataType, StructType, Encoder, StructEncoder, PrimitiveEncoder}
98

109
object DataFrameBuilders:
1110
extension [A](seq: Seq[A])(using encoder: Encoder[A])
1211
transparent inline def toTypedDF(using spark: SparkSession): DataFrame[?] = ${ toTypedDFImpl('seq, 'encoder, 'spark) }
1312

1413
private def toTypedDFImpl[A : Type](seq: Expr[Seq[A]], encoder: Expr[Encoder[A]], spark: Expr[SparkSession])(using Quotes) =
1514
val (schemaType, schema, encodeFun) = encoder match
16-
case '{ $e: DataType.StructEncoder.Aux[A, t] } =>
15+
case '{ $e: StructEncoder.Aux[A, t] } =>
1716
val schema = '{ ${ e }.catalystType }
1817
val encodeFun: Expr[A => sql.Row] = '{ ${ e }.encode }
1918
(Type.of[t], schema, encodeFun)
20-
case '{ $e: DataType.Encoder.Aux[tpe, t] } =>
19+
case '{ $e: Encoder.Aux[tpe, t] } =>
2120
val schema = '{
2221
sql.types.StructType(Seq(
2322
sql.types.StructField("value", ${ encoder }.catalystType, ${ encoder }.isNullable )

src/main/UntypedOps.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
package org.virtuslab.iskra
22

33
import scala.quoted.*
4-
import types.{DataType, StructType}
4+
import types.{DataType, Encoder, StructType, StructEncoder}
55

66
object UntypedOps:
77
extension (untyped: UntypedColumn)
88
def typed[A <: DataType] = Column[A](untyped)
99

1010
extension (df: UntypedDataFrame)
11-
transparent inline def typed[A](using encoder: DataType.StructEncoder[A]): DataFrame[?] = ${ typedDataFrameImpl('df, 'encoder) } // TODO: Check schema at runtime? Check if names of columns match?
11+
transparent inline def typed[A](using encoder: StructEncoder[A]): DataFrame[?] = ${ typedDataFrameImpl('df, 'encoder) } // TODO: Check schema at runtime? Check if names of columns match?
1212

13-
private def typedDataFrameImpl[A : Type](df: Expr[UntypedDataFrame], encoder: Expr[DataType.StructEncoder[A]])(using Quotes) =
13+
private def typedDataFrameImpl[A : Type](df: Expr[UntypedDataFrame], encoder: Expr[StructEncoder[A]])(using Quotes) =
1414
encoder match
15-
case '{ ${e}: DataType.Encoder.Aux[tpe, StructType[t]] } =>
15+
case '{ ${e}: Encoder.Aux[tpe, StructType[t]] } =>
1616
'{ DataFrame[t](${ df }) }

src/main/When.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package org.virtuslab.iskra
2+
3+
import org.apache.spark.sql.{functions => f, Column => UntypedColumn}
4+
import org.virtuslab.iskra.types.{Coerce, DataType, BooleanOptType}
5+
6+
object When:
7+
class WhenColumn[T <: DataType](untyped: UntypedColumn) extends Column[DataType.Nullable[T]](untyped):
8+
def when[U <: DataType](condition: Column[BooleanOptType], value: Column[U])(using coerce: Coerce[T, U]): WhenColumn[coerce.Coerced] =
9+
WhenColumn(this.untyped.when(condition.untyped, value.untyped))
10+
def otherwise[U <: DataType](value: Column[U])(using coerce: Coerce[T, U]): Column[coerce.Coerced] =
11+
Column(this.untyped.otherwise(value.untyped))
12+
13+
def when[T <: DataType](condition: Column[BooleanOptType], value: Column[T]): WhenColumn[T] =
14+
WhenColumn(f.when(condition.untyped, value.untyped))

src/main/api/api.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ export org.virtuslab.iskra.$
2828
export org.virtuslab.iskra.{Column, DataFrame, UntypedColumn, UntypedDataFrame, :=, /}
2929

3030
object functions:
31-
export org.virtuslab.iskra.functions.lit
31+
export org.virtuslab.iskra.functions.{lit, when}
3232
export org.virtuslab.iskra.functions.Aggregates.*
3333

3434
export org.apache.spark.sql.SparkSession

src/main/functions/lit.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ package org.virtuslab.iskra.functions
22

33
import org.apache.spark.sql
44
import org.virtuslab.iskra.Column
5-
import org.virtuslab.iskra.types.DataType.PrimitiveEncoder
5+
import org.virtuslab.iskra.types.PrimitiveEncoder
66

77
def lit[A](value: A)(using encoder: PrimitiveEncoder[A]): Column[encoder.ColumnType] = Column(sql.functions.lit(encoder.encode(value)))

src/main/functions/when.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package org.virtuslab.iskra
2+
package functions
3+
4+
export When.when

src/main/types/Coerce.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package org.virtuslab.iskra
2+
package types
3+
4+
import DataType.{CommonNumericNonNullableType, CommonNumericNullableType, NumericOptType, NumericType}
5+
6+
trait Coerce[-A <: DataType, -B <: DataType]:
7+
type Coerced <: DataType
8+
9+
object Coerce:
10+
given sameType[A <: DataType]: Coerce[A, A] with
11+
override type Coerced = A
12+
13+
given nullable[A <: NumericOptType, B <: NumericOptType]: Coerce[A, B] with
14+
override type Coerced = CommonNumericNullableType[A, B]
15+
16+
given nonNullable[A <: NumericType, B <: NumericType]: Coerce[A, B] with
17+
override type Coerced = CommonNumericNonNullableType[A, B]

src/main/types/DataType.scala

Lines changed: 1 addition & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
package org.virtuslab.iskra
22
package types
33

4-
import scala.quoted._
5-
import scala.deriving.Mirror
6-
import org.apache.spark.sql
7-
import MacroHelpers.TupleSubtype
8-
94
sealed trait DataType
105

116
object DataType:
@@ -38,7 +33,7 @@ object DataType:
3833
case DoubleOptType => DoubleOptType
3934
case StructOptType[schema] => StructOptType[schema]
4035

41-
type CommonNumericOptType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match
36+
type CommonNumericNullableType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match
4237
case (DoubleOptType, _) | (_, DoubleOptType) => DoubleOptType
4338
case (FloatOptType, _) | (_, FloatOptType) => FloatOptType
4439
case (LongOptType, _) | (_, LongOptType) => LongOptType
@@ -54,179 +49,6 @@ object DataType:
5449
case (ShortOptType, _) | (_, ShortOptType) => ShortType
5550
case (ByteOptType, _) | (_, ByteOptType) => ByteType
5651

57-
trait Encoder[-A]:
58-
type ColumnType <: DataType
59-
def encode(value: A): Any
60-
def decode(value: Any): Any
61-
def catalystType: sql.types.DataType
62-
def isNullable: Boolean
63-
64-
trait PrimitiveEncoder[-A] extends Encoder[A]
65-
66-
trait PrimitiveNullableEncoder[-A] extends PrimitiveEncoder[Option[A]]:
67-
def encode(value: Option[A]) = value.orNull
68-
def decode(value: Any) = Option(value)
69-
def isNullable = true
70-
71-
trait PrimitiveNonNullableEncoder[-A] extends PrimitiveEncoder[A]:
72-
def encode(value: A) = value
73-
def decode(value: Any) = value
74-
def isNullable = true
75-
76-
77-
object Encoder:
78-
type Aux[-A, E <: DataType] = Encoder[A] { type ColumnType = E }
79-
80-
inline given boolean: PrimitiveNonNullableEncoder[Boolean] with
81-
type ColumnType = BooleanType
82-
def catalystType = sql.types.BooleanType
83-
inline given booleanOpt: PrimitiveNullableEncoder[Boolean] with
84-
type ColumnType = BooleanOptType
85-
def catalystType = sql.types.BooleanType
86-
87-
inline given string: PrimitiveNonNullableEncoder[String] with
88-
type ColumnType = StringType
89-
def catalystType = sql.types.StringType
90-
inline given stringOpt: PrimitiveNullableEncoder[String] with
91-
type ColumnType = StringOptType
92-
def catalystType = sql.types.StringType
93-
94-
inline given byte: PrimitiveNonNullableEncoder[Byte] with
95-
type ColumnType = ByteType
96-
def catalystType = sql.types.ByteType
97-
inline given byteOpt: PrimitiveNullableEncoder[Byte] with
98-
type ColumnType = ByteOptType
99-
def catalystType = sql.types.ByteType
100-
101-
inline given short: PrimitiveNonNullableEncoder[Short] with
102-
type ColumnType = ShortType
103-
def catalystType = sql.types.ShortType
104-
inline given shortOpt: PrimitiveNullableEncoder[Short] with
105-
type ColumnType = ShortOptType
106-
def catalystType = sql.types.ShortType
107-
108-
inline given int: PrimitiveNonNullableEncoder[Int] with
109-
type ColumnType = IntegerType
110-
def catalystType = sql.types.IntegerType
111-
inline given intOpt: PrimitiveNullableEncoder[Int] with
112-
type ColumnType = IntegerOptType
113-
def catalystType = sql.types.IntegerType
114-
115-
inline given long: PrimitiveNonNullableEncoder[Long] with
116-
type ColumnType = LongType
117-
def catalystType = sql.types.LongType
118-
inline given longOpt: PrimitiveNullableEncoder[Long] with
119-
type ColumnType = LongOptType
120-
def catalystType = sql.types.LongType
121-
122-
inline given float: PrimitiveNonNullableEncoder[Float] with
123-
type ColumnType = FloatType
124-
def catalystType = sql.types.FloatType
125-
inline given floatOpt: PrimitiveNullableEncoder[Float] with
126-
type ColumnType = FloatOptType
127-
def catalystType = sql.types.FloatType
128-
129-
inline given double: PrimitiveNonNullableEncoder[Double] with
130-
type ColumnType = DoubleType
131-
def catalystType = sql.types.DoubleType
132-
inline given doubleOpt: PrimitiveNullableEncoder[Double] with
133-
type ColumnType = DoubleOptType
134-
def catalystType = sql.types.DoubleType
135-
136-
export StructEncoder.{fromMirror, optFromMirror}
137-
138-
trait StructEncoder[-A] extends Encoder[A]:
139-
type StructSchema <: Tuple
140-
type ColumnType = StructType[StructSchema]
141-
override def catalystType: sql.types.StructType
142-
override def encode(a: A): sql.Row
143-
144-
object StructEncoder:
145-
type Aux[-A, E <: Tuple] = StructEncoder[A] { type StructSchema = E }
146-
147-
private case class ColumnInfo(
148-
labelType: Type[?],
149-
labelValue: String,
150-
decodedType: Type[?],
151-
encoder: Expr[Encoder[?]]
152-
)
153-
154-
private def getColumnSchemaType(using quotes: Quotes)(subcolumnInfos: List[ColumnInfo]): Type[?] =
155-
subcolumnInfos match
156-
case Nil => Type.of[EmptyTuple]
157-
case info :: tail =>
158-
info.labelType match
159-
case '[Name.Subtype[lt]] =>
160-
info.encoder match
161-
case '{ ${encoder}: Encoder.Aux[tpe, DataType.Subtype[e]] } =>
162-
getColumnSchemaType(tail) match
163-
case '[TupleSubtype[tailType]] =>
164-
Type.of[(lt := e) *: tailType]
165-
166-
private def getSubcolumnInfos(elemLabels: Type[?], elemTypes: Type[?])(using Quotes): List[ColumnInfo] =
167-
import quotes.reflect.Select
168-
elemLabels match
169-
case '[EmptyTuple] => Nil
170-
case '[label *: labels] =>
171-
val labelValue = Type.valueOfConstant[label].get.toString
172-
elemTypes match
173-
case '[tpe *: types] =>
174-
Expr.summon[Encoder[tpe]] match
175-
case Some(encoderExpr) =>
176-
ColumnInfo(Type.of[label], labelValue, Type.of[tpe], encoderExpr) :: getSubcolumnInfos(Type.of[labels], Type.of[types])
177-
case _ => quotes.reflect.report.errorAndAbort(s"Could not summon encoder for ${Type.show[tpe]}")
178-
179-
transparent inline given fromMirror[A]: StructEncoder[A] = ${ fromMirrorImpl[A] }
180-
181-
def fromMirrorImpl[A : Type](using q: Quotes): Expr[StructEncoder[A]] =
182-
Expr.summon[Mirror.Of[A]].getOrElse(throw new Exception(s"Could not find Mirror when generating encoder for ${Type.show[A]}")) match
183-
case '{ ${mirror}: Mirror.ProductOf[A] { type MirroredElemLabels = elementLabels; type MirroredElemTypes = elementTypes } } =>
184-
val subcolumnInfos = getSubcolumnInfos(Type.of[elementLabels], Type.of[elementTypes])
185-
val columnSchemaType = getColumnSchemaType(subcolumnInfos)
186-
187-
val structFieldExprs = subcolumnInfos.map { info =>
188-
'{ sql.types.StructField(${Expr(info.labelValue)}, ${info.encoder}.catalystType, ${info.encoder}.isNullable) }
189-
}
190-
val structFields = Expr.ofSeq(structFieldExprs)
191-
192-
def rowElements(value: Expr[A]) =
193-
subcolumnInfos.map { info =>
194-
import quotes.reflect.*
195-
info.decodedType match
196-
case '[t] =>
197-
'{ ${info.encoder}.asInstanceOf[Encoder[t]].encode(${ Select.unique(value.asTerm, info.labelValue).asExprOf[t] }) }
198-
}
199-
200-
def rowElementsTuple(row: Expr[sql.Row]): Expr[Tuple] =
201-
val elements = subcolumnInfos.zipWithIndex.map { (info, idx) =>
202-
given Quotes = q
203-
'{ ${info.encoder}.decode(${row}.get(${Expr(idx)})) }
204-
}
205-
Expr.ofTupleFromSeq(elements)
206-
207-
columnSchemaType match
208-
case '[TupleSubtype[t]] =>
209-
'{
210-
(new StructEncoder[A] {
211-
override type StructSchema = t
212-
override def catalystType = sql.types.StructType(${ structFields })
213-
override def isNullable = false
214-
override def encode(a: A) =
215-
sql.Row.fromSeq(${ Expr.ofSeq(rowElements('a)) })
216-
override def decode(a: Any) =
217-
${mirror}.fromProduct(${ rowElementsTuple('{a.asInstanceOf[sql.Row]}) })
218-
}): StructEncoder[A] { type StructSchema = t }
219-
}
220-
end fromMirrorImpl
221-
222-
inline given optFromMirror[A](using encoder: StructEncoder[A]): (Encoder[Option[A]] { type ColumnType = StructOptType[encoder.StructSchema] }) =
223-
new Encoder[Option[A]]:
224-
override type ColumnType = StructOptType[encoder.StructSchema]
225-
override def encode(value: Option[A]): Any = value.map(encoder.encode).orNull
226-
override def decode(value: Any): Any = Option(encoder.decode)
227-
override def catalystType = encoder.catalystType
228-
override def isNullable = true
229-
23052
import DataType.NotNull
23153

23254
sealed class BooleanOptType extends DataType

0 commit comments

Comments (0)