Skip to content

Commit d1cfb99

Browse files
committed
Expressions constructors toSeq conversion
1 parent 43e8d3d commit d1cfb99

File tree

1 file changed

+115
-151
lines changed
  • core/src/main/scala/org/locationtech/rasterframes/expressions

1 file changed

+115
-151
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 115 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ package org.locationtech.rasterframes
2424
import geotrellis.raster.{DoubleConstantNoDataCellType, Tile}
2525
import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
2626
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
27-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, ScalaUDF}
27+
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
2828
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, ScalaReflection}
2929
import org.apache.spark.sql.types.DataType
3030
import org.apache.spark.sql.SQLContext
@@ -36,19 +36,23 @@ import org.locationtech.rasterframes.expressions.localops._
3636
import org.locationtech.rasterframes.expressions.focalops._
3737
import org.locationtech.rasterframes.expressions.tilestats._
3838
import org.locationtech.rasterframes.expressions.transformers._
39+
import shapeless.HList
40+
import shapeless.ops.function.FnToProduct
41+
import shapeless.ops.traversable.FromTraversable
3942

4043
import scala.reflect.ClassTag
4144
import scala.reflect.runtime.universe._
45+
import scala.language.implicitConversions
4246

4347
/**
4448
* Module of Catalyst expressions for efficiently working with tiles.
4549
*
4650
* @since 10/10/17
4751
*/
4852
package object expressions {
49-
type HasTernaryExpressionCopy = {def copy(first: Expression, second: Expression, third: Expression): Expression}
50-
type HasBinaryExpressionCopy = {def copy(left: Expression, right: Expression): Expression}
51-
type HasUnaryExpressionCopy = {def copy(child: Expression): Expression}
53+
type HasTernaryExpressionCopy = { def copy(first: Expression, second: Expression, third: Expression): Expression }
54+
type HasBinaryExpressionCopy = { def copy(left: Expression, right: Expression): Expression }
55+
type HasUnaryExpressionCopy = { def copy(child: Expression): Expression }
5256

5357
private[expressions] def row(input: Any) = input.asInstanceOf[InternalRow]
5458
/** Convert the tile to a floating point type as needed for scalar operations. */
@@ -67,33 +71,6 @@ package object expressions {
6771

6872
}
6973

70-
private def expressionInfo[T : ClassTag](name: String, since: Option[String], database: Option[String]): ExpressionInfo = {
71-
val clazz = scala.reflect.classTag[T].runtimeClass
72-
val df = clazz.getAnnotation(classOf[ExpressionDescription])
73-
if (df != null) {
74-
if (df.extended().isEmpty) {
75-
new ExpressionInfo(
76-
clazz.getCanonicalName,
77-
database.orNull,
78-
name,
79-
df.usage(),
80-
df.arguments(),
81-
df.examples(),
82-
df.note(),
83-
df.group(),
84-
since.getOrElse(df.since()),
85-
df.deprecated(),
86-
df.source())
87-
} else {
88-
// This exists for the backward compatibility with old `ExpressionDescription`s defining
89-
// the extended description in `extended()`.
90-
new ExpressionInfo(clazz.getCanonicalName, database.orNull, name, df.usage(), df.extended())
91-
}
92-
} else {
93-
new ExpressionInfo(clazz.getCanonicalName, name)
94-
}
95-
}
96-
9774
def register(sqlContext: SQLContext, database: Option[String] = None): Unit = {
9875
val registry = sqlContext.sparkSession.sessionState.functionRegistry
9976

@@ -103,127 +80,114 @@ package object expressions {
10380
registry.registerFunction(id, info, builder)
10481
}
10582

106-
def register1[T <: Expression : ClassTag](
107-
name: String,
108-
builder: Expression => T
109-
): Unit = registerFunction[T](name, None){ args => builder(args(0))
83+
/** Converts (expr1: Expression, ..., exprn: Expression) => R into a Seq[Expression] => R function */
84+
implicit def expressionArgumentsSequencer[F, I <: HList, R](f: F)(implicit ftp: FnToProduct.Aux[F, I => R], ft: FromTraversable[I]): Seq[Expression] => R = { list: Seq[Expression] =>
85+
ft(list) match {
86+
case Some(l) => ftp(f)(l)
87+
case None => throw new IllegalArgumentException(s"registerFunction application failed: arity mismatch: $list.")
88+
}
11089
}
11190

112-
def register2[T <: Expression : ClassTag](
113-
name: String,
114-
builder: (Expression, Expression) => T
115-
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1)) }
116-
117-
def register3[T <: Expression : ClassTag](
118-
name: String,
119-
builder: (Expression, Expression, Expression) => T
120-
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2)) }
121-
122-
def register5[T <: Expression : ClassTag](
123-
name: String,
124-
builder: (Expression, Expression, Expression, Expression, Expression) => T
125-
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2), args(3), args(4)) }
126-
127-
register2("rf_local_add", Add(_, _))
128-
register2("rf_local_subtract", Subtract(_, _))
129-
registerFunction("rf_explode_tiles"){ExplodeTiles(1.0, None, _)}
130-
register5("rf_assemble_tile", TileAssembler(_, _, _, _, _))
131-
register1("rf_cell_type", GetCellType(_))
132-
register2("rf_convert_cell_type", SetCellType(_, _))
133-
register2("rf_interpret_cell_type_as", InterpretAs(_, _))
134-
register2("rf_with_no_data", SetNoDataValue(_,_))
135-
register1("rf_dimensions", GetDimensions(_))
136-
register1("st_geometry", ExtentToGeometry(_))
137-
register1("rf_geometry", GetGeometry(_))
138-
register1("st_extent", GeometryToExtent(_))
139-
register1("rf_extent", GetExtent(_))
140-
register1("rf_crs", GetCRS(_))
141-
register1("rf_tile", RealizeTile(_))
142-
register3("rf_proj_raster", CreateProjectedRaster(_, _, _))
143-
register2("rf_local_multiply", Multiply(_, _))
144-
register2("rf_local_divide", Divide(_, _))
145-
register2("rf_normalized_difference", NormalizedDifference(_,_))
146-
register2("rf_local_less", Less(_, _))
147-
register2("rf_local_greater", Greater(_, _))
148-
register2("rf_local_less_equal", LessEqual(_, _))
149-
register2("rf_local_greater_equal", GreaterEqual(_, _))
150-
register2("rf_local_equal", Equal(_, _))
151-
register2("rf_local_unequal", Unequal(_, _))
152-
register2("rf_local_is_in", IsIn(_, _))
153-
register1("rf_local_no_data", Undefined(_))
154-
register1("rf_local_data", Defined(_))
155-
register2("rf_local_min", Min(_, _))
156-
register2("rf_local_max", Max(_, _))
157-
register3("rf_local_clamp", Clamp(_, _, _))
158-
register3("rf_where", Where(_, _, _))
159-
register3("rf_standardize", Standardize(_, _, _))
160-
register3("rf_rescale", Rescale(_, _ , _))
161-
register1("rf_tile_sum", Sum(_))
162-
register1("rf_round", Round(_))
163-
register1("rf_abs", Abs(_))
164-
register1("rf_log", Log(_))
165-
register1("rf_log10", Log10(_))
166-
register1("rf_log2", Log2(_))
167-
register1("rf_log1p", Log1p(_))
168-
register1("rf_exp", Exp(_))
169-
register1("rf_exp10", Exp10(_))
170-
register1("rf_exp2", Exp2(_))
171-
register1("rf_expm1", ExpM1(_))
172-
register1("rf_sqrt", Sqrt(_))
173-
register3("rf_resample", Resample(_, _, _))
174-
register2("rf_resample_nearest", ResampleNearest(_, _))
175-
register1("rf_tile_to_array_double", TileToArrayDouble(_))
176-
register1("rf_tile_to_array_int", TileToArrayInt(_))
177-
register1("rf_data_cells", DataCells(_))
178-
register1("rf_no_data_cells", NoDataCells(_))
179-
register1("rf_is_no_data_tile", IsNoDataTile(_))
180-
register1("rf_exists", Exists(_))
181-
register1("rf_for_all", ForAll(_))
182-
register1("rf_tile_min", TileMin(_))
183-
register1("rf_tile_max", TileMax(_))
184-
register1("rf_tile_mean", TileMean(_))
185-
register1("rf_tile_stats", TileStats(_))
186-
register1("rf_tile_histogram", TileHistogram(_))
187-
register1("rf_agg_data_cells", DataCells(_))
188-
register1("rf_agg_no_data_cells", CellCountAggregate.NoDataCells(_))
189-
register1("rf_agg_stats", CellStatsAggregate.CellStatsAggregateUDAF(_))
190-
register1("rf_agg_approx_histogram", HistogramAggregate.HistogramAggregateUDAF(_))
191-
register1("rf_agg_local_stats", LocalStatsAggregate.LocalStatsAggregateUDAF(_))
192-
register1("rf_agg_local_min",LocalTileOpAggregate.LocalMinUDAF(_))
193-
register1("rf_agg_local_max", LocalTileOpAggregate.LocalMaxUDAF(_))
194-
register1("rf_agg_local_data_cells", LocalCountAggregate.LocalDataCellsUDAF(_))
195-
register1("rf_agg_local_no_data_cells", LocalCountAggregate.LocalNoDataCellsUDAF(_))
196-
register1("rf_agg_local_mean", LocalMeanAggregate(_))
197-
register3(FocalMax.name, FocalMax(_, _, _))
198-
register3(FocalMin.name, FocalMin(_, _, _))
199-
register3(FocalMean.name, FocalMean(_, _, _))
200-
register3(FocalMode.name, FocalMode(_, _, _))
201-
register3(FocalMedian.name, FocalMedian(_, _, _))
202-
register3(FocalMoransI.name, FocalMoransI(_, _, _))
203-
register3(FocalStdDev.name, FocalStdDev(_, _, _))
204-
register3(Convolve.name, Convolve(_, _, _))
205-
206-
register3(Slope.name, Slope(_, _, _))
207-
register2(Aspect.name, Aspect(_, _))
208-
register5(Hillshade.name, Hillshade(_, _, _, _, _))
209-
210-
register2("rf_mask", MaskByDefined(_, _))
211-
register2("rf_inverse_mask", InverseMaskByDefined(_, _))
212-
register3("rf_mask_by_value", MaskByValue(_, _, _))
213-
register3("rf_inverse_mask_by_value", InverseMaskByValue(_, _, _))
214-
register3("rf_mask_by_values", MaskByValues(_, _, _))
215-
216-
register1("rf_render_ascii", DebugRender.RenderAscii(_))
217-
register1("rf_render_matrix", DebugRender.RenderMatrix(_))
218-
register1("rf_render_png", RenderPNG.RenderCompositePNG(_))
219-
register3("rf_rgb_composite", RGBComposite(_, _, _))
220-
221-
register2("rf_xz2_index", XZ2Indexer(_, _, 18.toShort))
222-
register2("rf_z2_index", Z2Indexer(_, _, 31.toShort))
223-
224-
register3("st_reproject", ReprojectGeometry(_, _, _))
225-
226-
register3[ExtractBits]("rf_local_extract_bits", ExtractBits(_: Expression, _: Expression, _: Expression))
227-
register3[ExtractBits]("rf_local_extract_bit", ExtractBits(_: Expression, _: Expression, _: Expression))
91+
registerFunction[Add](name = "rf_local_add")(Add.apply)
92+
registerFunction[Subtract](name = "rf_local_subtract")(Subtract.apply)
93+
registerFunction[ExplodeTiles](name = "rf_explode_tiles")(ExplodeTiles(1.0, None, _))
94+
registerFunction[TileAssembler](name = "rf_assemble_tile")(TileAssembler.apply)
95+
registerFunction[GetCellType](name = "rf_cell_type")(GetCellType.apply)
96+
registerFunction[SetCellType](name = "rf_convert_cell_type")(SetCellType.apply)
97+
registerFunction[InterpretAs](name = "rf_interpret_cell_type_as")(InterpretAs.apply)
98+
registerFunction[SetNoDataValue](name = "rf_with_no_data")(SetNoDataValue.apply)
99+
registerFunction[GetDimensions](name = "rf_dimensions")(GetDimensions.apply)
100+
registerFunction[ExtentToGeometry](name = "st_geometry")(ExtentToGeometry.apply)
101+
registerFunction[GetGeometry](name = "rf_geometry")(GetGeometry.apply)
102+
registerFunction[GeometryToExtent](name = "st_extent")(GeometryToExtent.apply)
103+
registerFunction[GetExtent](name = "rf_extent")(GetExtent.apply)
104+
registerFunction[GetCRS](name = "rf_crs")(GetCRS.apply)
105+
registerFunction[RealizeTile](name = "rf_tile")(RealizeTile.apply)
106+
registerFunction[CreateProjectedRaster](name = "rf_proj_raster")(CreateProjectedRaster.apply)
107+
registerFunction[Multiply](name = "rf_local_multiply")(Multiply.apply)
108+
registerFunction[Divide](name = "rf_local_divide")(Divide.apply)
109+
registerFunction[NormalizedDifference](name = "rf_normalized_difference")(NormalizedDifference.apply)
110+
registerFunction[Less](name = "rf_local_less")(Less.apply)
111+
registerFunction[Greater](name = "rf_local_greater")(Greater.apply)
112+
registerFunction[LessEqual](name = "rf_local_less_equal")(LessEqual.apply)
113+
registerFunction[GreaterEqual](name = "rf_local_greater_equal")(GreaterEqual.apply)
114+
registerFunction[Equal](name = "rf_local_equal")(Equal.apply)
115+
registerFunction[Unequal](name = "rf_local_unequal")(Unequal.apply)
116+
registerFunction[IsIn](name = "rf_local_is_in")(IsIn.apply)
117+
registerFunction[Undefined](name = "rf_local_no_data")(Undefined.apply)
118+
registerFunction[Defined](name = "rf_local_data")(Defined.apply)
119+
registerFunction[Min](name = "rf_local_min")(Min.apply)
120+
registerFunction[Max](name = "rf_local_max")(Max.apply)
121+
registerFunction[Clamp](name = "rf_local_clamp")(Clamp.apply)
122+
registerFunction[Where](name = "rf_where")(Where.apply)
123+
registerFunction[Standardize](name = "rf_standardize")(Standardize.apply)
124+
registerFunction[Rescale](name = "rf_rescale")(Rescale.apply)
125+
registerFunction[Sum](name = "rf_tile_sum")(Sum.apply)
126+
registerFunction[Round](name = "rf_round")(Round.apply)
127+
registerFunction[Abs](name = "rf_abs")(Abs.apply)
128+
registerFunction[Log](name = "rf_log")(Log.apply)
129+
registerFunction[Log10](name = "rf_log10")(Log10.apply)
130+
registerFunction[Log2](name = "rf_log2")(Log2.apply)
131+
registerFunction[Log1p](name = "rf_log1p")(Log1p.apply)
132+
registerFunction[Exp](name = "rf_exp")(Exp.apply)
133+
registerFunction[Exp10](name = "rf_exp10")(Exp10.apply)
134+
registerFunction[Exp2](name = "rf_exp2")(Exp2.apply)
135+
registerFunction[ExpM1](name = "rf_expm1")(ExpM1.apply)
136+
registerFunction[Sqrt](name = "rf_sqrt")(Sqrt.apply)
137+
registerFunction[Resample](name = "rf_resample")(Resample.apply)
138+
registerFunction[ResampleNearest](name = "rf_resample_nearest")(ResampleNearest.apply)
139+
registerFunction[TileToArrayDouble](name = "rf_tile_to_array_double")(TileToArrayDouble.apply)
140+
registerFunction[TileToArrayInt](name = "rf_tile_to_array_int")(TileToArrayInt.apply)
141+
registerFunction[DataCells](name = "rf_data_cells")(DataCells.apply)
142+
registerFunction[NoDataCells](name = "rf_no_data_cells")(NoDataCells.apply)
143+
registerFunction[IsNoDataTile](name = "rf_is_no_data_tile")(IsNoDataTile.apply)
144+
registerFunction[Exists](name = "rf_exists")(Exists.apply)
145+
registerFunction[ForAll](name = "rf_for_all")(ForAll.apply)
146+
registerFunction[TileMin](name = "rf_tile_min")(TileMin.apply)
147+
registerFunction[TileMax](name = "rf_tile_max")(TileMax.apply)
148+
registerFunction[TileMean](name = "rf_tile_mean")(TileMean.apply)
149+
registerFunction[TileStats](name = "rf_tile_stats")(TileStats.apply)
150+
registerFunction[TileHistogram](name = "rf_tile_histogram")(TileHistogram.apply)
151+
registerFunction[DataCells](name = "rf_agg_data_cells")(DataCells.apply)
152+
registerFunction[CellCountAggregate.NoDataCells](name = "rf_agg_no_data_cells")(CellCountAggregate.NoDataCells.apply)
153+
registerFunction[CellStatsAggregate.CellStatsAggregateUDAF](name = "rf_agg_stats")(CellStatsAggregate.CellStatsAggregateUDAF.apply)
154+
registerFunction[HistogramAggregate.HistogramAggregateUDAF](name = "rf_agg_approx_histogram")(HistogramAggregate.HistogramAggregateUDAF.apply)
155+
registerFunction[LocalStatsAggregate.LocalStatsAggregateUDAF](name = "rf_agg_local_stats")(LocalStatsAggregate.LocalStatsAggregateUDAF.apply)
156+
registerFunction[LocalTileOpAggregate.LocalMinUDAF](name = "rf_agg_local_min")(LocalTileOpAggregate.LocalMinUDAF.apply)
157+
registerFunction[LocalTileOpAggregate.LocalMaxUDAF](name = "rf_agg_local_max")(LocalTileOpAggregate.LocalMaxUDAF.apply)
158+
registerFunction[LocalCountAggregate.LocalDataCellsUDAF](name = "rf_agg_local_data_cells")(LocalCountAggregate.LocalDataCellsUDAF.apply)
159+
registerFunction[LocalCountAggregate.LocalNoDataCellsUDAF](name = "rf_agg_local_no_data_cells")(LocalCountAggregate.LocalNoDataCellsUDAF.apply)
160+
registerFunction[LocalMeanAggregate](name = "rf_agg_local_mean")(LocalMeanAggregate.apply)
161+
registerFunction[FocalMax](FocalMax.name)(FocalMax.apply)
162+
registerFunction[FocalMin](FocalMin.name)(FocalMin.apply)
163+
registerFunction[FocalMean](FocalMean.name)(FocalMean.apply)
164+
registerFunction[FocalMode](FocalMode.name)(FocalMode.apply)
165+
registerFunction[FocalMedian](FocalMedian.name)(FocalMedian.apply)
166+
registerFunction[FocalMoransI](FocalMoransI.name)(FocalMoransI.apply)
167+
registerFunction[FocalStdDev](FocalStdDev.name)(FocalStdDev.apply)
168+
registerFunction[Convolve](Convolve.name)(Convolve.apply)
169+
170+
registerFunction[Slope](Slope.name)(Slope.apply)
171+
registerFunction[Aspect](Aspect.name)(Aspect.apply)
172+
registerFunction[Hillshade](Hillshade.name)(Hillshade.apply)
173+
174+
registerFunction[MaskByDefined](name = "rf_mask")(MaskByDefined.apply)
175+
registerFunction[InverseMaskByDefined](name = "rf_inverse_mask")(InverseMaskByDefined.apply)
176+
registerFunction[MaskByValue](name = "rf_mask_by_value")(MaskByValue.apply)
177+
registerFunction[InverseMaskByValue](name = "rf_inverse_mask_by_value")(InverseMaskByValue.apply)
178+
registerFunction[MaskByValues](name = "rf_mask_by_values")(MaskByValues.apply)
179+
180+
registerFunction[DebugRender.RenderAscii](name = "rf_render_ascii")(DebugRender.RenderAscii.apply)
181+
registerFunction[DebugRender.RenderMatrix](name = "rf_render_matrix")(DebugRender.RenderMatrix.apply)
182+
registerFunction[RenderPNG.RenderCompositePNG](name = "rf_render_png")(RenderPNG.RenderCompositePNG.apply)
183+
registerFunction[RGBComposite](name = "rf_rgb_composite")(RGBComposite.apply)
184+
185+
registerFunction[XZ2Indexer](name = "rf_xz2_index")(XZ2Indexer(_: Expression, _: Expression, 18.toShort))
186+
registerFunction[Z2Indexer](name = "rf_z2_index")(Z2Indexer(_: Expression, _: Expression, 31.toShort))
187+
188+
registerFunction[ReprojectGeometry](name = "st_reproject")(ReprojectGeometry.apply)
189+
190+
registerFunction[ExtractBits]("rf_local_extract_bits")(ExtractBits.apply)
191+
registerFunction[ExtractBits]("rf_local_extract_bit")(ExtractBits.apply)
228192
}
229193
}

0 commit comments

Comments
 (0)