Skip to content

Commit ef2f4ee

Browse files
committed
Register functions directly
this is a starting point
1 parent 8da8bd7 commit ef2f4ee

File tree

1 file changed

+164
-112
lines changed
  • core/src/main/scala/org/locationtech/rasterframes/expressions

1 file changed

+164
-112
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 164 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@
2222
package org.locationtech.rasterframes
2323

2424
import geotrellis.raster.{DoubleConstantNoDataCellType, Tile}
25-
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
25+
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase}
2626
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
27-
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
28-
import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
29-
import org.apache.spark.sql.rf.VersionShims._
27+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, ScalaUDF}
28+
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, ScalaReflection}
3029
import org.apache.spark.sql.types.DataType
31-
import org.apache.spark.sql.{SQLContext, rf}
30+
import org.apache.spark.sql.{SQLContext}
3231
import org.locationtech.rasterframes.expressions.accessors._
3332
import org.locationtech.rasterframes.expressions.aggregates.CellCountAggregate.DataCells
3433
import org.locationtech.rasterframes.expressions.aggregates._
@@ -38,6 +37,7 @@ import org.locationtech.rasterframes.expressions.focalops._
3837
import org.locationtech.rasterframes.expressions.tilestats._
3938
import org.locationtech.rasterframes.expressions.transformers._
4039

40+
import scala.reflect.ClassTag
4141
import scala.reflect.runtime.universe._
4242

4343
/**
@@ -64,114 +64,166 @@ package object expressions {
6464
/**
 * Creates a factory that wraps a child expression in a [[ScalaUDF]] named `name`.
 *
 * @param name label attached to the resulting UDF (passed as `udfName`)
 * @param f    curried function: first applied to the child expression's data type,
 *             then to each input value of type `A1`
 * @tparam RT  result type; determines the UDF's Catalyst data type
 * @tparam A1  type of the single input value
 * @return a function from the child expression to the `ScalaUDF` evaluating `f` over it
 */
def udfiexpr[RT: TypeTag, A1: TypeTag](name: String, f: DataType => A1 => RT): Expression => ScalaUDF = { child: Expression =>
  // Catalyst data type derived by reflection from the declared result type.
  val returnType = ScalaReflection.schemaFor[RT].dataType
  // Encoder for RT, resolved and bound before being handed to the UDF.
  val rtEncoder = ExpressionEncoder[RT]().resolveAndBind()
  ScalaUDF(
    (row: A1) => f(child.dataType)(row),
    returnType,
    Seq(child),
    Seq(Option(rtEncoder)),
    udfName = Some(name)
  )
}
69+
70+
/**
 * Assembles the [[ExpressionInfo]] for expression class `T`, reading its
 * [[ExpressionDescription]] annotation when one is present.
 *
 * @param name     function name recorded in the info
 * @param since    overrides the annotation's `since` value when defined
 * @param database database recorded in the info; `null` is passed through when empty
 * @tparam T       expression class whose annotation and canonical name are inspected
 */
private def expressionInfo[T : ClassTag](name: String, since: Option[String], database: Option[String]): ExpressionInfo = {
  val exprClass = scala.reflect.classTag[T].runtimeClass
  val className = exprClass.getCanonicalName

  Option(exprClass.getAnnotation(classOf[ExpressionDescription])) match {
    case None =>
      // No annotation on the class: only class name and function name are known.
      new ExpressionInfo(className, name)

    case Some(desc) if desc.extended().nonEmpty =>
      // This exists for the backward compatibility with old `ExpressionDescription`s defining
      // the extended description in `extended()`.
      new ExpressionInfo(className, database.orNull, name, desc.usage(), desc.extended())

    case Some(desc) =>
      new ExpressionInfo(
        className,
        database.orNull,
        name,
        desc.usage(),
        desc.arguments(),
        desc.examples(),
        desc.note(),
        desc.group(),
        since.getOrElse(desc.since()),
        desc.deprecated(),
        desc.source())
  }
}
6896

69-
def register(sqlContext: SQLContext): Unit = {
70-
// Expression-oriented functions have a different registration scheme
71-
// Currently have to register with the `builtin` registry due to Spark data hiding.
72-
val registry: FunctionRegistry = rf.registry(sqlContext)
73-
74-
registry.registerExpression[Add]("rf_local_add")
75-
registry.registerExpression[Subtract]("rf_local_subtract")
76-
registry.registerExpression[TileAssembler]("rf_assemble_tile")
77-
registry.registerExpression[ExplodeTiles]("rf_explode_tiles")
78-
registry.registerExpression[GetCellType]("rf_cell_type")
79-
registry.registerExpression[SetCellType]("rf_convert_cell_type")
80-
registry.registerExpression[InterpretAs]("rf_interpret_cell_type_as")
81-
registry.registerExpression[SetNoDataValue]("rf_with_no_data")
82-
registry.registerExpression[GetDimensions]("rf_dimensions")
83-
registry.registerExpression[ExtentToGeometry]("st_geometry")
84-
registry.registerExpression[GetGeometry]("rf_geometry")
85-
registry.registerExpression[GeometryToExtent]("st_extent")
86-
registry.registerExpression[GetExtent]("rf_extent")
87-
registry.registerExpression[GetCRS]("rf_crs")
88-
registry.registerExpression[RealizeTile]("rf_tile")
89-
registry.registerExpression[CreateProjectedRaster]("rf_proj_raster")
90-
registry.registerExpression[Multiply]("rf_local_multiply")
91-
registry.registerExpression[Divide]("rf_local_divide")
92-
registry.registerExpression[NormalizedDifference]("rf_normalized_difference")
93-
registry.registerExpression[Less]("rf_local_less")
94-
registry.registerExpression[Greater]("rf_local_greater")
95-
registry.registerExpression[LessEqual]("rf_local_less_equal")
96-
registry.registerExpression[GreaterEqual]("rf_local_greater_equal")
97-
registry.registerExpression[Equal]("rf_local_equal")
98-
registry.registerExpression[Unequal]("rf_local_unequal")
99-
registry.registerExpression[IsIn]("rf_local_is_in")
100-
registry.registerExpression[Undefined]("rf_local_no_data")
101-
registry.registerExpression[Defined]("rf_local_data")
102-
registry.registerExpression[Min]("rf_local_min")
103-
registry.registerExpression[Max]("rf_local_max")
104-
registry.registerExpression[Clamp]("rf_local_clamp")
105-
registry.registerExpression[Where]("rf_where")
106-
registry.registerExpression[Standardize]("rf_standardize")
107-
registry.registerExpression[Rescale]("rf_rescale")
108-
registry.registerExpression[Sum]("rf_tile_sum")
109-
registry.registerExpression[Round]("rf_round")
110-
registry.registerExpression[Abs]("rf_abs")
111-
registry.registerExpression[Log]("rf_log")
112-
registry.registerExpression[Log10]("rf_log10")
113-
registry.registerExpression[Log2]("rf_log2")
114-
registry.registerExpression[Log1p]("rf_log1p")
115-
registry.registerExpression[Exp]("rf_exp")
116-
registry.registerExpression[Exp10]("rf_exp10")
117-
registry.registerExpression[Exp2]("rf_exp2")
118-
registry.registerExpression[ExpM1]("rf_expm1")
119-
registry.registerExpression[Sqrt]("rf_sqrt")
120-
registry.registerExpression[Resample]("rf_resample")
121-
registry.registerExpression[ResampleNearest]("rf_resample_nearest")
122-
registry.registerExpression[TileToArrayDouble]("rf_tile_to_array_double")
123-
registry.registerExpression[TileToArrayInt]("rf_tile_to_array_int")
124-
registry.registerExpression[DataCells]("rf_data_cells")
125-
registry.registerExpression[NoDataCells]("rf_no_data_cells")
126-
registry.registerExpression[IsNoDataTile]("rf_is_no_data_tile")
127-
registry.registerExpression[Exists]("rf_exists")
128-
registry.registerExpression[ForAll]("rf_for_all")
129-
registry.registerExpression[TileMin]("rf_tile_min")
130-
registry.registerExpression[TileMax]("rf_tile_max")
131-
registry.registerExpression[TileMean]("rf_tile_mean")
132-
registry.registerExpression[TileStats]("rf_tile_stats")
133-
registry.registerExpression[TileHistogram]("rf_tile_histogram")
134-
registry.registerExpression[DataCells]("rf_agg_data_cells")
135-
registry.registerExpression[CellCountAggregate.NoDataCells]("rf_agg_no_data_cells")
136-
registry.registerExpression[CellStatsAggregate.CellStatsAggregateUDAF]("rf_agg_stats")
137-
registry.registerExpression[HistogramAggregate.HistogramAggregateUDAF]("rf_agg_approx_histogram")
138-
registry.registerExpression[LocalStatsAggregate.LocalStatsAggregateUDAF]("rf_agg_local_stats")
139-
registry.registerExpression[LocalTileOpAggregate.LocalMinUDAF]("rf_agg_local_min")
140-
registry.registerExpression[LocalTileOpAggregate.LocalMaxUDAF]("rf_agg_local_max")
141-
registry.registerExpression[LocalCountAggregate.LocalDataCellsUDAF]("rf_agg_local_data_cells")
142-
registry.registerExpression[LocalCountAggregate.LocalNoDataCellsUDAF]("rf_agg_local_no_data_cells")
143-
registry.registerExpression[LocalMeanAggregate]("rf_agg_local_mean")
144-
145-
registry.registerExpression[FocalMax](FocalMax.name)
146-
registry.registerExpression[FocalMin](FocalMin.name)
147-
registry.registerExpression[FocalMean](FocalMean.name)
148-
registry.registerExpression[FocalMode](FocalMode.name)
149-
registry.registerExpression[FocalMedian](FocalMedian.name)
150-
registry.registerExpression[FocalMoransI](FocalMoransI.name)
151-
registry.registerExpression[FocalStdDev](FocalStdDev.name)
152-
registry.registerExpression[Convolve](Convolve.name)
153-
154-
registry.registerExpression[Slope](Slope.name)
155-
registry.registerExpression[Aspect](Aspect.name)
156-
registry.registerExpression[Hillshade](Hillshade.name)
157-
158-
registry.registerExpression[Mask.MaskByDefined]("rf_mask")
159-
registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask")
160-
registry.registerExpression[Mask.MaskByValue]("rf_mask_by_value")
161-
registry.registerExpression[Mask.InverseMaskByValue]("rf_inverse_mask_by_value")
162-
registry.registerExpression[Mask.MaskByValues]("rf_mask_by_values")
163-
164-
registry.registerExpression[DebugRender.RenderAscii]("rf_render_ascii")
165-
registry.registerExpression[DebugRender.RenderMatrix]("rf_render_matrix")
166-
registry.registerExpression[RenderPNG.RenderCompositePNG]("rf_render_png")
167-
registry.registerExpression[RGBComposite]("rf_rgb_composite")
168-
169-
registry.registerExpression[XZ2Indexer]("rf_xz2_index")
170-
registry.registerExpression[Z2Indexer]("rf_z2_index")
171-
172-
registry.registerExpression[transformers.ReprojectGeometry]("st_reproject")
173-
174-
registry.registerExpression[ExtractBits]("rf_local_extract_bits")
175-
registry.registerExpression[ExtractBits]("rf_local_extract_bit")
97+
/**
 * Registers the RasterFrames expressions as SQL functions in the function registry of the
 * session behind `sqlContext`.
 *
 * @param sqlContext context whose session-state function registry receives the functions
 * @param database   optional database used to qualify every registered function identifier
 *                   and recorded in each function's `ExpressionInfo`
 */
def register(sqlContext: SQLContext, database: Option[String] = None): Unit = {
  val registry = sqlContext.sparkSession.sessionState.functionRegistry

  // Raised when a function is invoked with an unexpected number of arguments, so the caller
  // sees a readable message rather than a bare scala.MatchError from the builder.
  def invalidArity(name: String, expected: Int, args: Seq[Expression]): Nothing =
    throw new IllegalArgumentException(
      s"Invalid number of arguments for function $name. Expected: $expected; Found: ${args.length}")

  // Registers `builder` under `name`, qualified by `database`.
  def registerFunction[T <: Expression : ClassTag](name: String, since: Option[String] = None)(builder: Seq[Expression] => T): Unit = {
    val id = FunctionIdentifier(name, database)
    // Use the database-aware helper so the info agrees with the identifier it is registered under.
    val info = expressionInfo[T](name, since, database)
    registry.registerFunction(id, info, builder)
  }

  // Fixed-arity adapters: unpack the argument sequence and delegate to `builder`.
  def register1[T <: Expression : ClassTag](
    name: String,
    builder: Expression => T
  ): Unit = registerFunction[T](name, None) {
    case Seq(a) => builder(a)
    case args => invalidArity(name, 1, args)
  }

  def register2[T <: Expression : ClassTag](
    name: String,
    builder: (Expression, Expression) => T
  ): Unit = registerFunction[T](name, None) {
    case Seq(a, b) => builder(a, b)
    case args => invalidArity(name, 2, args)
  }

  def register3[T <: Expression : ClassTag](
    name: String,
    builder: (Expression, Expression, Expression) => T
  ): Unit = registerFunction[T](name, None) {
    case Seq(a, b, c) => builder(a, b, c)
    case args => invalidArity(name, 3, args)
  }

  def register5[T <: Expression : ClassTag](
    name: String,
    builder: (Expression, Expression, Expression, Expression, Expression) => T
  ): Unit = registerFunction[T](name, None) {
    case Seq(a, b, c, d, e) => builder(a, b, c, d, e)
    case args => invalidArity(name, 5, args)
  }

  register2("rf_local_add", Add(_, _))
  register2("rf_local_subtract", Subtract(_, _))
  // Receives the whole argument sequence, so it is registered without an arity adapter.
  registerFunction("rf_explode_tiles"){ExplodeTiles(1.0, None, _)}
  register5("rf_assemble_tile", TileAssembler(_, _, _, _, _))
  register1("rf_cell_type", GetCellType(_))
  register2("rf_convert_cell_type", SetCellType(_, _))
  register2("rf_interpret_cell_type_as", InterpretAs(_, _))
  register2("rf_with_no_data", SetNoDataValue(_, _))
  register1("rf_dimensions", GetDimensions(_))
  register1("st_geometry", ExtentToGeometry(_))
  register1("rf_geometry", GetGeometry(_))
  register1("st_extent", GeometryToExtent(_))
  register1("rf_extent", GetExtent(_))
  register1("rf_crs", GetCRS(_))
  register1("rf_tile", RealizeTile(_))
  register3("rf_proj_raster", CreateProjectedRaster(_, _, _))
  register2("rf_local_multiply", Multiply(_, _))
  register2("rf_local_divide", Divide(_, _))
  register2("rf_normalized_difference", NormalizedDifference(_, _))
  register2("rf_local_less", Less(_, _))
  register2("rf_local_greater", Greater(_, _))
  register2("rf_local_less_equal", LessEqual(_, _))
  register2("rf_local_greater_equal", GreaterEqual(_, _))
  register2("rf_local_equal", Equal(_, _))
  register2("rf_local_unequal", Unequal(_, _))
  register2("rf_local_is_in", IsIn(_, _))
  register1("rf_local_no_data", Undefined(_))
  register1("rf_local_data", Defined(_))
  register2("rf_local_min", Min(_, _))
  register2("rf_local_max", Max(_, _))
  register3("rf_local_clamp", Clamp(_, _, _))
  register3("rf_where", Where(_, _, _))
  register3("rf_standardize", Standardize(_, _, _))
  register3("rf_rescale", Rescale(_, _, _))
  register1("rf_tile_sum", Sum(_))
  register1("rf_round", Round(_))
  register1("rf_abs", Abs(_))
  register1("rf_log", Log(_))
  register1("rf_log10", Log10(_))
  register1("rf_log2", Log2(_))
  register1("rf_log1p", Log1p(_))
  register1("rf_exp", Exp(_))
  register1("rf_exp10", Exp10(_))
  register1("rf_exp2", Exp2(_))
  register1("rf_expm1", ExpM1(_))
  register1("rf_sqrt", Sqrt(_))
  register3("rf_resample", Resample(_, _, _))
  register2("rf_resample_nearest", ResampleNearest(_, _))
  register1("rf_tile_to_array_double", TileToArrayDouble(_))
  register1("rf_tile_to_array_int", TileToArrayInt(_))
  // NOTE(review): "rf_data_cells" and "rf_agg_data_cells" (below) both use the identifier
  // `DataCells`, so they register the same expression class — confirm this is intended.
  register1("rf_data_cells", DataCells(_))
  register1("rf_no_data_cells", NoDataCells(_))
  register1("rf_is_no_data_tile", IsNoDataTile(_))
  register1("rf_exists", Exists(_))
  register1("rf_for_all", ForAll(_))
  register1("rf_tile_min", TileMin(_))
  register1("rf_tile_max", TileMax(_))
  register1("rf_tile_mean", TileMean(_))
  register1("rf_tile_stats", TileStats(_))
  register1("rf_tile_histogram", TileHistogram(_))
  register1("rf_agg_data_cells", DataCells(_))
  register1("rf_agg_no_data_cells", CellCountAggregate.NoDataCells(_))
  register1("rf_agg_stats", CellStatsAggregate.CellStatsAggregateUDAF(_))
  register1("rf_agg_approx_histogram", HistogramAggregate.HistogramAggregateUDAF(_))
  register1("rf_agg_local_stats", LocalStatsAggregate.LocalStatsAggregateUDAF(_))
  register1("rf_agg_local_min", LocalTileOpAggregate.LocalMinUDAF(_))
  register1("rf_agg_local_max", LocalTileOpAggregate.LocalMaxUDAF(_))
  register1("rf_agg_local_data_cells", LocalCountAggregate.LocalDataCellsUDAF(_))
  register1("rf_agg_local_no_data_cells", LocalCountAggregate.LocalNoDataCellsUDAF(_))
  register1("rf_agg_local_mean", LocalMeanAggregate(_))

  register3(FocalMax.name, FocalMax(_, _, _))
  register3(FocalMin.name, FocalMin(_, _, _))
  register3(FocalMean.name, FocalMean(_, _, _))
  register3(FocalMode.name, FocalMode(_, _, _))
  register3(FocalMedian.name, FocalMedian(_, _, _))
  register3(FocalMoransI.name, FocalMoransI(_, _, _))
  register3(FocalStdDev.name, FocalStdDev(_, _, _))
  register3(Convolve.name, Convolve(_, _, _))

  register3(Slope.name, Slope(_, _, _))
  register2(Aspect.name, Aspect(_, _))
  register5(Hillshade.name, Hillshade(_, _, _, _, _))

  register2("rf_mask", Mask.MaskByDefined(_, _))
  register2("rf_inverse_mask", Mask.InverseMaskByDefined(_, _))
  register3("rf_mask_by_value", Mask.MaskByValue(_, _, _))
  register3("rf_inverse_mask_by_value", Mask.InverseMaskByValue(_, _, _))
  register2("rf_mask_by_values", Mask.MaskByValues(_, _))

  register1("rf_render_ascii", DebugRender.RenderAscii(_))
  register1("rf_render_matrix", DebugRender.RenderMatrix(_))
  register1("rf_render_png", RenderPNG.RenderCompositePNG(_))
  register3("rf_rgb_composite", RGBComposite(_, _, _))

  // The SQL functions take two arguments; the third constructor argument is fixed here.
  register2("rf_xz2_index", XZ2Indexer(_, _, 18.toShort))
  register2("rf_z2_index", Z2Indexer(_, _, 31.toShort))

  register3("st_reproject", ReprojectGeometry(_, _, _))

  // Both names are registered against the same expression.
  register3[ExtractBits]("rf_local_extract_bits", ExtractBits(_: Expression, _: Expression, _: Expression))
  register3[ExtractBits]("rf_local_extract_bit", ExtractBits(_: Expression, _: Expression, _: Expression))
}
177229
}

0 commit comments

Comments
 (0)